Skip to content

Commit

Permalink
Reimplementation of data management classes, fix bugs and improve eff…
Browse files Browse the repository at this point in the history
…iciency of LabelTensor
  • Loading branch information
FilippoOlivo committed Nov 25, 2024
1 parent e851c33 commit 63bc3a7
Show file tree
Hide file tree
Showing 28 changed files with 682 additions and 997 deletions.
10 changes: 5 additions & 5 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
__all__ = [
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
"Trainer", "LabelTensor", "Plotter", "Condition",
"PinaDataModule", 'TorchOptimizer', 'Graph', 'LabelParameter'
]

from .meta import *
from .label_tensor import LabelTensor
from .label_tensor import LabelTensor, LabelParameter
from .solvers.solver import SolverInterface
from .trainer import Trainer
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset

from .data import PinaDataModule
from .data import PinaDataLoader

from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
18 changes: 12 additions & 6 deletions pina/collector.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


Expand Down Expand Up @@ -66,9 +67,12 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
condition_domain = condition.domain
if isinstance(condition_domain, str):
condition_domain = self.problem.domains[condition_domain]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
if not self._is_conditions_ready[loc]:
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
Expand All @@ -84,10 +88,11 @@ def store_sample_domains(self, n, mode, variables, sample_locations):

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition_domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
if set(pts.labels).issubset(sorted(self.problem.input_variables)):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
Expand All @@ -110,5 +115,6 @@ def add_points(self, new_points_dict):
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
self.data_collections[k]['input_points'] = LabelTensor.vstack(
[self.data_collections[k][
'input_points'], v])
2 changes: 1 addition & 1 deletion pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DataConditionInterface(ConditionInterface):

def __init__(self, input_points, conditional_variables=None):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
4 changes: 2 additions & 2 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class DomainEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, domain, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.domain = domain
self.equation = equation

def __setattr__(self, key, value):
if key == 'domain':
check_consistency(value, (DomainInterface))
check_consistency(value, (DomainInterface, str))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key == 'equation':
check_consistency(value, (EquationInterface))
Expand Down
2 changes: 1 addition & 1 deletion pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class InputPointsEquationCondition(ConditionInterface):
condition_type = ['physics']
def __init__(self, input_points, equation):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
Expand Down
5 changes: 3 additions & 2 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch_geometric

from .condition_interface import ConditionInterface
from ..label_tensor import LabelTensor
Expand All @@ -16,15 +17,15 @@ class InputOutputPointsCondition(ConditionInterface):
condition_type = ['supervised']
def __init__(self, input_points, output_points):
"""
TODO
TODO : add docstring
"""
super().__init__()
self.input_points = input_points
self.output_points = output_points

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
check_consistency(value, (LabelTensor, Graph, torch.Tensor, torch_geometric.data.Data))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
13 changes: 5 additions & 8 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
'PinaDataModule',
'PinaDataset'
]

from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch


from .data_module import PinaDataModule
from .base_dataset import BaseDataset
from .dataset import PinaDataset
156 changes: 0 additions & 156 deletions pina/data/base_dataset.py

This file was deleted.

Loading

0 comments on commit 63bc3a7

Please sign in to comment.