Skip to content

Commit

Permalink
Implement Dataset, Dataloader and DataModule class and fix Supervised…
Browse files Browse the repository at this point in the history
…Solver
  • Loading branch information
FilippoOlivo committed Oct 22, 2024
1 parent 1818dc6 commit 25fe0fd
Show file tree
Hide file tree
Showing 30 changed files with 778 additions and 792 deletions.
6 changes: 4 additions & 2 deletions pina/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
"Plotter",
"Condition",
"SamplePointDataset",
"SamplePointLoader",
"PinaDataModule",
"PinaDataLoader"
]

from .meta import *
Expand All @@ -15,4 +16,5 @@
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset
from .data import SamplePointLoader
from .data import PinaDataModule
from .data import PinaDataLoader
27 changes: 14 additions & 13 deletions pina/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,29 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


class Collector:
def __init__(self, problem):
# creating a hook between collector and problem
self.problem = problem
self.problem = problem

# this variable is used to store the data in the form:
# {'[condition_name]' :
# {'input_points' : Tensor,
# '[equation/output_points/conditional_variables]': Tensor}
# }
# those variables are used for the dataloading
self._data_collections = {name : {} for name in self.problem.conditions}
self._data_collections = {name: {} for name in self.problem.conditions}

# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name : False for name in self.problem.conditions}
name: False for name in self.problem.conditions}
self.full = False

@property
def full(self):
return all(self._is_conditions_ready.values())

@full.setter
def full(self, value):
check_consistency(value, bool)
Expand All @@ -37,7 +38,7 @@ def data_collections(self):
@property
def problem(self):
return self._problem

@problem.setter
def problem(self, value):
self._problem = value
Expand Down Expand Up @@ -76,14 +77,14 @@ def store_sample_domains(self, n, mode, variables, sample_locations):

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (
set(pts.labels).issubset(sorted(self.problem.input_variables))
):
set(pts.labels).issubset(sorted(self.problem.input_variables))
):
pts = pts.sort_labels()
if sorted(pts.labels)==sorted(self.problem.input_variables):
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
Expand All @@ -97,7 +98,7 @@ def add_points(self, new_points_dict):
:param new_points_dict: Dictionary of input points (condition_name: LabelTensor)
:raises RuntimeError: if at least one condition is not already sampled
"""
for k,v in new_points_dict.items():
for k, v in new_points_dict.items():
if not self._is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
2 changes: 1 addition & 1 deletion pina/condition/condition_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ class ConditionInterface(metaclass=ABCMeta):

condition_types = ['physics', 'supervised', 'unsupervised']

def __init__(self, *args, **wargs):
def __init__(self, *args, **kwargs):
self._condition_type = None
self._problem = None

Expand Down
4 changes: 2 additions & 2 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def __init__(self, input_points, conditional_variables=None):
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self.condition_type = 'unsupervised'
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
4 changes: 2 additions & 2 deletions pina/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, domain, equation):
super().__init__()
self.domain = domain
self.equation = equation
self.condition_type = 'physics'
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
Expand All @@ -29,5 +29,5 @@ def __setattr__(self, key, value):
elif key == 'equation':
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
2 changes: 1 addition & 1 deletion pina/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ def __setattr__(self, key, value):
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
4 changes: 2 additions & 2 deletions pina/condition/input_output_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def __init__(self, input_points, output_points):
super().__init__()
self.input_points = input_points
self.output_points = output_points
self.condition_type = ['supervised', 'physics']
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
19 changes: 16 additions & 3 deletions pina/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,20 @@
"""
Import data classes
"""
__all__ = [
'PinaDataLoader',
'SupervisedDataset',
'SamplePointDataset',
'UnsupervisedDataset',
'Batch',
'PinaDataModule',
'BaseDataset'
]

from .pina_dataloader import SamplePointLoader
from .data_dataset import DataPointDataset
from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .pina_batch import Batch
from .unsupervised_dataset import UnsupervisedDataset
from .pina_batch import Batch
from .data_module import PinaDataModule
from .base_dataset import BaseDataset
107 changes: 107 additions & 0 deletions pina/data/base_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Basic data module implementation
"""
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor


class BaseDataset(Dataset):
    """
    Base class for datasets, handling initialization and data retrieval.

    Subclasses must declare a non-empty ``__slots__`` listing the names of
    the tensors they store. A condition of the problem is loaded into the
    dataset only when the set of its ``LabelTensor`` entries matches
    ``__slots__`` exactly.

    :var condition_indices: tensor holding, for every stored point, the
        index of the condition it comes from
    :var device: torch.device
    :var condition_names: dict of condition index and corresponding name
    """

    def __new__(cls, problem, device):
        """
        Ensure correct definition of __slots__ before initialization

        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        :raises TypeError: if ``BaseDataset`` itself is instantiated, or if
            the subclass does not define a non-empty ``__slots__``.
        """
        if cls is BaseDataset:
            raise TypeError('BaseDataset cannot be instantiated directly. Use a subclass.')
        # ``hasattr`` alone is not enough: an empty ``__slots__`` may be
        # inherited from a base class, which would let a subclass without
        # its own slots through and fail later on ``self.__slots__[0]``.
        if not getattr(cls, '__slots__', None):
            raise TypeError('Something is wrong, __slots__ must be defined in subclasses.')
        return super().__new__(cls)

    def __init__(self, problem, device):
        """
        Initialize the object based on __slots__

        :param AbstractProblem problem: The formulation of the problem.
        :param torch.device device: The device on which the
            dataset will be loaded.
        """
        super().__init__()

        self.condition_names = {}
        collector = problem.collector
        slots = self.__slots__
        for slot in slots:
            setattr(self, slot, [])

        # Keep only the conditions whose LabelTensor entries are exactly the
        # ones named in __slots__; each kept condition gets a running index.
        expected_keys = sorted(slots)
        idx = 0
        for name, data in collector.data_collections.items():
            keys = [k for k, v in data.items() if isinstance(v, LabelTensor)]
            if sorted(keys) != expected_keys:
                continue
            for slot in slots:
                getattr(self, slot).append(data[slot])
            self.condition_names[idx] = name
            idx += 1

        input_list = getattr(self, slots[0])
        if input_list:
            # One condition index per stored point, in stacking order.
            # NOTE(review): uint8 caps the number of distinct conditions
            # at 256 -- confirm this is acceptable for all problems.
            self.condition_indices = torch.cat(
                [
                    torch.tensor([i] * len(points), dtype=torch.uint8)
                    for i, points in enumerate(input_list)
                ],
                dim=0,
            )
            for slot in slots:
                setattr(self, slot, LabelTensor.vstack(getattr(self, slot)))
        else:
            # No matching condition: fall back to empty tensors.
            self.condition_indices = torch.tensor([], dtype=torch.uint8)
            for slot in slots:
                setattr(self, slot, torch.tensor([]))

        self.device = device

    def __len__(self):
        """
        Return the length of the data stored in the first slot.
        """
        return len(getattr(self, self.__slots__[0]))

    def __getattribute__(self, item):
        """
        Return the requested attribute; float32 ``LabelTensor`` attributes
        are moved to ``self.device`` and flagged as requiring gradients.
        """
        attribute = super().__getattribute__(item)
        if isinstance(attribute, LabelTensor) and attribute.dtype == torch.float32:
            attribute = attribute.to(device=self.device).requires_grad_()
        return attribute

    def __getitem__(self, idx):
        """
        Retrieve data from the dataset.

        Supported indices:

        * ``str``: the whole tensor stored in that attribute;
        * ``slice``: a list with the sliced tensor of every slot;
        * ``(name, list_or_slice)``: the selected rows of attribute ``name``;
        * tuple/list of ``int``: a list with the selected rows of every slot.

        :raises ValueError: if ``idx`` is of none of the forms above.
        """
        if isinstance(idx, str):
            return getattr(self, idx).to(self.device)

        # The index is passed wrapped in a list, as in the original code;
        # presumably this is what LabelTensor indexing expects -- confirm.
        if isinstance(idx, slice):
            return [
                getattr(self, slot)[[idx]].to(self.device)
                for slot in self.__slots__
            ]

        if isinstance(idx, (tuple, list)):
            if (len(idx) == 2 and isinstance(idx[0], str)
                    and isinstance(idx[1], (list, slice))):
                tensor = getattr(self, idx[0])
                return tensor[[idx[1]]].to(self.device)
            if all(isinstance(x, int) for x in idx):
                return [
                    getattr(self, slot)[[idx]].to(self.device)
                    for slot in self.__slots__
                ]

        raise ValueError(f'Invalid index {idx}')
41 changes: 0 additions & 41 deletions pina/data/data_dataset.py

This file was deleted.

Loading

0 comments on commit 25fe0fd

Please sign in to comment.