Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of DataLoader and DataModule #383

Merged
merged 15 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph',
]

from .meta import *
Expand All @@ -9,9 +9,9 @@
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
18 changes: 12 additions & 6 deletions pina/collector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


Expand Down Expand Up @@ -66,9 +67,12 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
condition_domain = condition.domain
if isinstance(condition_domain, str):
condition_domain = self.problem.domains[condition_domain]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
Expand All @@ -84,10 +88,11 @@ def store_sample_domains(self, n, mode, variables, sample_locations):

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition_domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
Expand All @@ -110,5 +115,6 @@ def add_points(self, new_points_dict):
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
self.data_collections[k]['input_points'] = LabelTensor.vstack(
[self.data_collections[k][
'input_points'], v])
3 changes: 1 addition & 2 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ class DataConditionInterface(ConditionInterface):

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
Expand Down
5 changes: 2 additions & 3 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,15 @@ class DomainEquationCondition(ConditionInterface):

def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
check_consistency(value, (DomainInterface))
check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
Expand Down
3 changes: 1 addition & 2 deletions pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ class InputPointsEquationCondition(ConditionInterface):

def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.equation = equation
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'input_points':
Expand Down
6 changes: 3 additions & 3 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch_geometric

from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
Expand All @@ -16,16 +17,15 @@ class InputOutputPointsCondition(ConditionInterface):

def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
13 changes: 5 additions & 8 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
'PinaDataModule',
'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch


from .data_module import PinaDataModule
from .base_dataset import BaseDataset
from .dataset import PinaDataset
157 changes: 0 additions & 157 deletions pina/data/base_dataset.py

This file was deleted.

Loading
Loading