Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver #368

Merged 5 commits on Oct 23, 2024
12 changes: 4 additions & 8 deletions pina/__init__.py
@@ -1,11 +1,6 @@
__all__ = [
"PINN",
"Trainer",
"LabelTensor",
"Plotter",
"Condition",
"SamplePointDataset",
"SamplePointLoader",
"PINN", "Trainer", "LabelTensor", "Plotter", "Condition",
"SamplePointDataset", "PinaDataModule", "PinaDataLoader"
]

from .meta import *
@@ -15,4 +10,5 @@
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset
from .data import SamplePointLoader
from .data import PinaDataModule
from .data import PinaDataLoader
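
For orientation, a minimal sketch of the reorganized top-level imports after this change (illustrative only; it assumes the package is installed from this branch):

# Hedged sketch: the top-level API now exposes PinaDataModule and
# PinaDataLoader in place of SamplePointLoader (see the updated __all__).
from pina import Trainer, LabelTensor, Condition
from pina import SamplePointDataset, PinaDataModule, PinaDataLoader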
43 changes: 26 additions & 17 deletions pina/collector.py
@@ -3,28 +3,29 @@
from . import LabelTensor
from .utils import check_consistency, merge_tensors


class Collector:
def __init__(self, problem):
# creating a hook between collector and problem
self.problem = problem

# this variable is used to store the data in the form:
# {'[condition_name]' :
# {'input_points' : Tensor,
# '[equation/output_points/conditional_variables]': Tensor}
# }
# those variables are used for the dataloading
self._data_collections = {name : {} for name in self.problem.conditions}
self._data_collections = {name: {} for name in self.problem.conditions}

# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name : False for name in self.problem.conditions}
name: False for name in self.problem.conditions}
self.full = False

@property
def full(self):
return all(self._is_conditions_ready.values())

@full.setter
def full(self, value):
check_consistency(value, bool)
@@ -37,7 +38,7 @@ def data_collections(self):
@property
def problem(self):
return self._problem

@problem.setter
def problem(self, value):
self._problem = value
@@ -47,7 +48,8 @@ def store_fixed_data(self):
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and 'domain' is not an attribute
# of the condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (not hasattr(condition, "domain")):
if (not self._is_conditions_ready[condition_name]) and (
not hasattr(condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
@@ -68,27 +70,32 @@ def store_sample_domains(self, n, mode, variables, sample_locations):
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [self.data_collections[loc]['input_points']]
already_sampled = [
self.data_collections[loc]['input_points']]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
already_sampled = []

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
condition.domain.sample(n=n, mode=mode,
variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (
set(pts.labels).issubset(sorted(self.problem.input_variables))
):
set(pts.labels).issubset(
sorted(self.problem.input_variables))
):
pts = pts.sort_labels()
if sorted(pts.labels)==sorted(self.problem.input_variables):
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError('Try to sample variables which are not in problem defined in the problem')
raise RuntimeError(
'Trying to sample variables that are not defined '
'in the problem')

def add_points(self, new_points_dict):
"""
@@ -97,7 +104,9 @@ def add_points(self, new_points_dict):
:param new_points_dict: Dictionary of input points (condition_name: LabelTensor)
:raises RuntimeError: if at least one condition is not already sampled
"""
for k,v in new_points_dict.items():
for k, v in new_points_dict.items():
if not self._is_conditions_ready[k]:
raise RuntimeError('Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v)
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
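
To make the layout documented in Collector.__init__ concrete, a hypothetical example of the resulting data collection dictionary follows; the condition names, labels and tensors are invented for illustration, and the equation entry would be an Equation object in practice.

import torch
from pina import LabelTensor

# Hypothetical layout of collector.data_collections after sampling:
# one entry per condition, each holding 'input_points' plus either an
# equation or supervised outputs (all names below are made up).
pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
out = LabelTensor(torch.rand(10, 1), ['u'])
data_collections = {
    'boundary': {'input_points': pts, 'equation': None},
    'measurements': {'input_points': pts, 'output_points': out},
}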
23 changes: 12 additions & 11 deletions pina/condition/condition.py
@@ -5,6 +5,7 @@
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface


class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
@@ -38,23 +39,23 @@ class Condition:
"""

__slots__ = list(
set(
InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__
)
)

def __new__(cls, *args, **kwargs):

if len(args) != 0:
raise ValueError(
"Condition takes only the following keyword "
f"arguments: {Condition.__slots__}."
)

sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
@@ -66,4 +67,4 @@ def __new__(cls, *args, **kwargs):
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
2 changes: 1 addition & 1 deletion pina/condition/condition_interface.py
@@ -5,7 +5,7 @@ class ConditionInterface(metaclass=ABCMeta):

condition_types = ['physics', 'supervised', 'unsupervised']

def __init__(self, *args, **wargs):
def __init__(self, *args, **kwargs):
self._condition_type = None
self._problem = None

4 changes: 2 additions & 2 deletions pina/condition/data_condition.py
@@ -22,11 +22,11 @@ def __init__(self, input_points, conditional_variables=None):
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self.condition_type = 'unsupervised'
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
4 changes: 2 additions & 2 deletions pina/condition/domain_equation_condition.py
@@ -20,7 +20,7 @@ def __init__(self, domain, equation):
super().__init__()
self.domain = domain
self.equation = equation
self.condition_type = 'physics'
self._condition_type = 'physics'

def __setattr__(self, key, value):
if key == 'domain':
@@ -29,5 +29,5 @@ def __setattr__(self, key, value):
elif key == 'equation':
check_consistency(value, (EquationInterface))
DomainEquationCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
2 changes: 1 addition & 1 deletion pina/condition/input_equation_condition.py
@@ -30,5 +30,5 @@ def __setattr__(self, key, value):
elif key == 'equation':
check_consistency(value, (EquationInterface))
InputPointsEquationCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
4 changes: 2 additions & 2 deletions pina/condition/input_output_condition.py
@@ -21,11 +21,11 @@ def __init__(self, input_points, output_points):
super().__init__()
self.input_points = input_points
self.output_points = output_points
self.condition_type = ['supervised', 'physics']
self._condition_type = ['supervised', 'physics']

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'output_points'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
InputOutputPointsCondition.__dict__[key].__set__(self, value)
elif key in ('problem', 'condition_type'):
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
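
The recurring change in these condition classes, assigning _condition_type and whitelisting _problem/_condition_type, matters because every attribute assignment is routed through __setattr__ and names that match no branch are silently dropped. A simplified sketch of the pattern, not the actual PINA implementation:

class ExampleCondition:
    # simplified stand-in for the PINA condition classes
    __slots__ = ['input_points', 'output_points', '_problem', '_condition_type']

    def __setattr__(self, key, value):
        if key in ('input_points', 'output_points'):
            # validated data slots (check_consistency omitted here)
            ExampleCondition.__dict__[key].__set__(self, value)
        elif key in ('_problem', '_condition_type'):
            # private bookkeeping set by ConditionInterface.__init__
            object.__setattr__(self, key, value)
        # any other name falls through and is silently ignored

cond = ExampleCondition()
cond._condition_type = 'supervised'   # stored
cond.condition_type = 'supervised'    # ignored: no matching branch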
14 changes: 11 additions & 3 deletions pina/data/__init__.py
@@ -1,7 +1,15 @@
"""
Import data classes
"""
__all__ = [
'PinaDataLoader', 'SupervisedDataset', 'SamplePointDataset',
'UnsupervisedDataset', 'Batch', 'PinaDataModule', 'BaseDataset'
]

from .pina_dataloader import SamplePointLoader
from .data_dataset import DataPointDataset
from .pina_dataloader import PinaDataLoader
from .supervised_dataset import SupervisedDataset
from .sample_dataset import SamplePointDataset
from .pina_batch import Batch
from .unsupervised_dataset import UnsupervisedDataset
from .data_module import PinaDataModule
from .base_dataset import BaseDataset
116 changes: 116 additions & 0 deletions pina/data/base_dataset.py
@@ -0,0 +1,116 @@
"""
Basic data module implementation
"""
from torch.utils.data import Dataset
import torch
from ..label_tensor import LabelTensor
from ..graph import Graph


class BaseDataset(Dataset):
"""
BaseDataset class, which handles initialization and data retrieval.
:var condition_indices: tensor assigning each data point to the index of its condition
:var device: torch.device
:var condition_names: dict of condition index and corresponding name
"""

def __new__(cls, problem, device):
"""
Ensure correct definition of __slots__ before initialization
:param AbstractProblem problem: The formulation of the problem.
:param torch.device device: The device on which the
dataset will be loaded.
"""
if cls is BaseDataset:
raise TypeError(
'BaseDataset cannot be instantiated directly. Use a subclass.')
if not hasattr(cls, '__slots__'):
raise TypeError(
'Something is wrong, __slots__ must be defined in subclasses.')
return object.__new__(cls)

def __init__(self, problem, device):
""""
Initialize the object based on __slots__
:param AbstractProblem problem: The formulation of the problem.
:param torch.device device: The device on which the
dataset will be loaded.
"""
super().__init__()

self.condition_names = {}
collector = problem.collector
for slot in self.__slots__:
setattr(self, slot, [])
num_el_per_condition = []
idx = 0
for name, data in collector.data_collections.items():
keys = list(data.keys())
current_cond_num_el = None
if sorted(self.__slots__) == sorted(keys):
for slot in self.__slots__:
slot_data = data[slot]
if isinstance(slot_data, (LabelTensor, torch.Tensor,
Graph)):
if current_cond_num_el is None:
current_cond_num_el = len(slot_data)
elif current_cond_num_el != len(slot_data):
raise ValueError(
'Inconsistent number of elements across slots '
'in the same condition')
current_list = getattr(self, slot)
current_list += [data[slot]] if not (
isinstance(data[slot], list)) else data[slot]
num_el_per_condition.append(current_cond_num_el)
self.condition_names[idx] = name
idx += 1
if num_el_per_condition:
self.condition_indices = torch.cat(
[
torch.tensor([i] * num_el_per_condition[i],
dtype=torch.uint8)
for i in range(len(num_el_per_condition))
],
dim=0,
)
for slot in self.__slots__:
current_attribute = getattr(self, slot)
if all(isinstance(a, LabelTensor) for a in current_attribute):
setattr(self, slot, LabelTensor.vstack(current_attribute))
else:
self.condition_indices = torch.tensor([], dtype=torch.uint8)
for slot in self.__slots__:
setattr(self, slot, torch.tensor([]))
self.device = device

def __len__(self):
return len(getattr(self, self.__slots__[0]))

def __getattribute__(self, item):
attribute = super().__getattribute__(item)
if isinstance(attribute,
LabelTensor) and attribute.dtype == torch.float32:
attribute = attribute.to(device=self.device).requires_grad_()
return attribute

def __getitem__(self, idx):
if isinstance(idx, str):
return getattr(self, idx).to(self.device)
if isinstance(idx, slice):
to_return_list = []
for i in self.__slots__:
to_return_list.append(getattr(self, i)[idx].to(self.device))
return to_return_list

if isinstance(idx, (tuple, list)):
if (len(idx) == 2 and isinstance(idx[0], str)
and isinstance(idx[1], (list, slice))):
tensor = getattr(self, idx[0])
return tensor[[idx[1]]].to(self.device)
if all(isinstance(x, int) for x in idx):
to_return_list = []
for i in self.__slots__:
to_return_list.append(
getattr(self, i)[[idx]].to(self.device))
return to_return_list

raise ValueError(f'Invalid index {idx}')
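
As a hedged illustration of how a concrete dataset is expected to plug into BaseDataset, a hypothetical subclass follows; the slot names are inferred from the collector layout above, and construction and indexing are shown as comments because they require an already collected problem.

from pina.data import BaseDataset

class MyDataset(BaseDataset):
    # hypothetical subclass; the real slot names come from the concrete
    # dataset classes shipped in this PR (supervised, sample, unsupervised)
    __slots__ = ['input_points', 'output_points']

# dataset = MyDataset(problem, device)       # problem must already be collected
# dataset['input_points']                    # full tensor for one slot
# dataset[0:32]                              # list with one slice per slot
# dataset[('input_points', [0, 3])]          # rows 0 and 3 of a single slot
# dataset[[0, 1, 2]]                         # rows 0-2 for every slot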