From 25fe0fd58c41c6f85edf85f9bf5fe803af3da486 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Wed, 16 Oct 2024 11:24:37 +0200 Subject: [PATCH] Implement Dataset, Dataloader and DataModule class and fix SupervisedSolver --- pina/__init__.py | 6 +- pina/collector.py | 27 +- pina/condition/condition_interface.py | 2 +- pina/condition/data_condition.py | 4 +- pina/condition/domain_equation_condition.py | 4 +- pina/condition/input_equation_condition.py | 2 +- pina/condition/input_output_condition.py | 4 +- pina/data/__init__.py | 19 +- pina/data/base_dataset.py | 107 ++++++++ pina/data/data_dataset.py | 41 --- pina/data/data_module.py | 172 +++++++++++++ pina/data/pina_batch.py | 57 ++--- pina/data/pina_dataloader.py | 220 +++------------- pina/data/pina_subset.py | 21 ++ pina/data/sample_dataset.py | 49 +--- pina/data/supervised_dataset.py | 12 + pina/data/unsupervised_dataset.py | 13 + pina/domain/cartesian.py | 2 +- pina/domain/ellipsoid.py | 3 +- pina/domain/operation_interface.py | 2 +- pina/domain/simplex.py | 4 +- pina/domain/union_domain.py | 6 +- pina/label_tensor.py | 32 +-- pina/operators.py | 2 +- pina/solvers/pinns/basepinn.py | 16 +- pina/solvers/solver.py | 255 +++++-------------- pina/solvers/supervised.py | 56 ++-- pina/trainer.py | 24 +- tests/test_dataset.py | 170 ++++++++----- tests/test_solvers/test_supervised_solver.py | 238 ++++++++--------- 30 files changed, 778 insertions(+), 792 deletions(-) create mode 100644 pina/data/base_dataset.py delete mode 100644 pina/data/data_dataset.py create mode 100644 pina/data/data_module.py create mode 100644 pina/data/pina_subset.py create mode 100644 pina/data/supervised_dataset.py create mode 100644 pina/data/unsupervised_dataset.py diff --git a/pina/__init__.py b/pina/__init__.py index 0fe93752..d110d284 100644 --- a/pina/__init__.py +++ b/pina/__init__.py @@ -5,7 +5,8 @@ "Plotter", "Condition", "SamplePointDataset", - "SamplePointLoader", + "PinaDataModule", + "PinaDataLoader" ] from .meta import * @@ -15,4 +16,5 @@ from .plotter import Plotter from .condition.condition import Condition from .data import SamplePointDataset -from .data import SamplePointLoader +from .data import PinaDataModule +from .data import PinaDataLoader \ No newline at end of file diff --git a/pina/collector.py b/pina/collector.py index f44c222a..f9ef194d 100644 --- a/pina/collector.py +++ b/pina/collector.py @@ -3,10 +3,11 @@ from . 
import LabelTensor from .utils import check_consistency, merge_tensors + class Collector: def __init__(self, problem): # creating a hook between collector and problem - self.problem = problem + self.problem = problem # this variable is used to store the data in the form: # {'[condition_name]' : @@ -14,17 +15,17 @@ def __init__(self, problem): # '[equation/output_points/conditional_variables]': Tensor} # } # those variables are used for the dataloading - self._data_collections = {name : {} for name in self.problem.conditions} + self._data_collections = {name: {} for name in self.problem.conditions} # variables used to check that all conditions are sampled self._is_conditions_ready = { - name : False for name in self.problem.conditions} + name: False for name in self.problem.conditions} self.full = False - + @property def full(self): return all(self._is_conditions_ready.values()) - + @full.setter def full(self, value): check_consistency(value, bool) @@ -37,7 +38,7 @@ def data_collections(self): @property def problem(self): return self._problem - + @problem.setter def problem(self, value): self._problem = value @@ -76,14 +77,14 @@ def store_sample_domains(self, n, mode, variables, sample_locations): # get the samples samples = [ - condition.domain.sample(n=n, mode=mode, variables=variables) - ] + already_sampled + condition.domain.sample(n=n, mode=mode, variables=variables) + ] + already_sampled pts = merge_tensors(samples) if ( - set(pts.labels).issubset(sorted(self.problem.input_variables)) - ): + set(pts.labels).issubset(sorted(self.problem.input_variables)) + ): pts = pts.sort_labels() - if sorted(pts.labels)==sorted(self.problem.input_variables): + if sorted(pts.labels) == sorted(self.problem.input_variables): self._is_conditions_ready[loc] = True values = [pts, condition.equation] self.data_collections[loc] = dict(zip(keys, values)) @@ -97,7 +98,7 @@ def add_points(self, new_points_dict): :param new_points_dict: Dictonary of input points (condition_name: LabelTensor) :raises RuntimeError: if at least one condition is not already sampled """ - for k,v in new_points_dict.items(): + for k, v in new_points_dict.items(): if not self._is_conditions_ready[k]: raise RuntimeError('Cannot add points on a non sampled condition') - self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) \ No newline at end of file + self.data_collections[k]['input_points'] = self.data_collections[k]['input_points'].vstack(v) diff --git a/pina/condition/condition_interface.py b/pina/condition/condition_interface.py index 52699b66..808c06af 100644 --- a/pina/condition/condition_interface.py +++ b/pina/condition/condition_interface.py @@ -5,7 +5,7 @@ class ConditionInterface(metaclass=ABCMeta): condition_types = ['physics', 'supervised', 'unsupervised'] - def __init__(self, *args, **wargs): + def __init__(self, *args, **kwargs): self._condition_type = None self._problem = None diff --git a/pina/condition/data_condition.py b/pina/condition/data_condition.py index 90d248b6..3bcd4be6 100644 --- a/pina/condition/data_condition.py +++ b/pina/condition/data_condition.py @@ -22,11 +22,11 @@ def __init__(self, input_points, conditional_variables=None): super().__init__() self.input_points = input_points self.conditional_variables = conditional_variables - self.condition_type = 'unsupervised' + self._condition_type = 'unsupervised' def __setattr__(self, key, value): if (key == 'input_points') or (key == 'conditional_variables'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) 
DataConditionInterface.__dict__[key].__set__(self, value) - elif key in ('problem', 'condition_type'): + elif key in ('_problem', '_condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/domain_equation_condition.py b/pina/condition/domain_equation_condition.py index ce4c7d3f..28315655 100644 --- a/pina/condition/domain_equation_condition.py +++ b/pina/condition/domain_equation_condition.py @@ -20,7 +20,7 @@ def __init__(self, domain, equation): super().__init__() self.domain = domain self.equation = equation - self.condition_type = 'physics' + self._condition_type = 'physics' def __setattr__(self, key, value): if key == 'domain': @@ -29,5 +29,5 @@ def __setattr__(self, key, value): elif key == 'equation': check_consistency(value, (EquationInterface)) DomainEquationCondition.__dict__[key].__set__(self, value) - elif key in ('problem', 'condition_type'): + elif key in ('_problem', '_condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_equation_condition.py b/pina/condition/input_equation_condition.py index ac47fa2c..0d34dfc9 100644 --- a/pina/condition/input_equation_condition.py +++ b/pina/condition/input_equation_condition.py @@ -30,5 +30,5 @@ def __setattr__(self, key, value): elif key == 'equation': check_consistency(value, (EquationInterface)) InputPointsEquationCondition.__dict__[key].__set__(self, value) - elif key in ('problem', 'condition_type'): + elif key in ('_problem', '_condition_type'): super().__setattr__(key, value) \ No newline at end of file diff --git a/pina/condition/input_output_condition.py b/pina/condition/input_output_condition.py index f8fd46e8..8a17495d 100644 --- a/pina/condition/input_output_condition.py +++ b/pina/condition/input_output_condition.py @@ -21,11 +21,11 @@ def __init__(self, input_points, output_points): super().__init__() self.input_points = input_points self.output_points = output_points - self.condition_type = ['supervised', 'physics'] + self._condition_type = ['supervised', 'physics'] def __setattr__(self, key, value): if (key == 'input_points') or (key == 'output_points'): check_consistency(value, (LabelTensor, Graph, torch.Tensor)) InputOutputPointsCondition.__dict__[key].__set__(self, value) - elif key in ('problem', 'condition_type'): + elif key in ('_problem', '_condition_type'): super().__setattr__(key, value) diff --git a/pina/data/__init__.py b/pina/data/__init__.py index fba19b92..0a1b5905 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -1,7 +1,20 @@ +""" +Import data classes +""" __all__ = [ + 'PinaDataLoader', + 'SupervisedDataset', + 'SamplePointDataset', + 'UnsupervisedDataset', + 'Batch', + 'PinaDataModule', + 'BaseDataset' ] -from .pina_dataloader import SamplePointLoader -from .data_dataset import DataPointDataset +from .pina_dataloader import PinaDataLoader +from .supervised_dataset import SupervisedDataset from .sample_dataset import SamplePointDataset -from .pina_batch import Batch \ No newline at end of file +from .unsupervised_dataset import UnsupervisedDataset +from .pina_batch import Batch +from .data_module import PinaDataModule +from .base_dataset import BaseDataset diff --git a/pina/data/base_dataset.py b/pina/data/base_dataset.py new file mode 100644 index 00000000..f095afa0 --- /dev/null +++ b/pina/data/base_dataset.py @@ -0,0 +1,107 @@ +""" +Basic data module implementation +""" +from torch.utils.data import Dataset +import torch +from ..label_tensor import LabelTensor + + +class 
BaseDataset(Dataset):
+    """
+    BaseDataset class, which handles initialization and data retrieval
+    :var condition_indices: List of indices
+    :var device: torch.device
+    :var condition_names: dict of condition index and corresponding name
+    """
+
+    def __new__(cls, problem, device):
+        """
+        Ensure correct definition of __slots__ before initialization
+        :param AbstractProblem problem: The formulation of the problem.
+        :param torch.device device: The device on which the
+            dataset will be loaded.
+        """
+        if cls is BaseDataset:
+            raise TypeError('BaseDataset cannot be instantiated directly. Use a subclass.')
+        if not hasattr(cls, '__slots__'):
+            raise TypeError('__slots__ must be defined in every subclass of BaseDataset.')
+        return super().__new__(cls)
+
+    def __init__(self, problem, device):
+        """
+        Initialize the object based on __slots__
+        :param AbstractProblem problem: The formulation of the problem.
+        :param torch.device device: The device on which the
+            dataset will be loaded.
+        """
+        super().__init__()
+
+        self.condition_names = {}
+        collector = problem.collector
+        for slot in self.__slots__:
+            setattr(self, slot, [])
+
+        idx = 0
+        for name, data in collector.data_collections.items():
+            keys = []
+            for k, v in data.items():
+                if isinstance(v, LabelTensor):
+                    keys.append(k)
+            if sorted(self.__slots__) == sorted(keys):
+                for slot in self.__slots__:
+                    current_list = getattr(self, slot)
+                    current_list.append(data[slot])
+                self.condition_names[idx] = name
+                idx += 1
+
+        if len(getattr(self, self.__slots__[0])) > 0:
+            input_list = getattr(self, self.__slots__[0])
+            self.condition_indices = torch.cat(
+                [
+                    torch.tensor([i] * len(input_list[i]), dtype=torch.uint8)
+                    for i in range(len(self.condition_names))
+                ],
+                dim=0,
+            )
+            for slot in self.__slots__:
+                current_attribute = getattr(self, slot)
+                setattr(self, slot, LabelTensor.vstack(current_attribute))
+        else:
+            self.condition_indices = torch.tensor([], dtype=torch.uint8)
+            for slot in self.__slots__:
+                setattr(self, slot, torch.tensor([]))
+
+        self.device = device
+
+    def __len__(self):
+        return len(getattr(self, self.__slots__[0]))
+
+    def __getattribute__(self, item):
+        attribute = super().__getattribute__(item)
+        if isinstance(attribute, LabelTensor) and attribute.dtype == torch.float32:
+            attribute = attribute.to(device=self.device).requires_grad_()
+        return attribute
+
+    def __getitem__(self, idx):
+        if isinstance(idx, str):
+            return getattr(self, idx).to(self.device)
+
+        if isinstance(idx, slice):
+            to_return_list = []
+            for i in self.__slots__:
+                to_return_list.append(getattr(self, i)[[idx]].to(self.device))
+            return to_return_list
+
+        if isinstance(idx, (tuple, list)):
+            if (len(idx) == 2 and isinstance(idx[0], str)
+                    and isinstance(idx[1], (list, slice))):
+                tensor = getattr(self, idx[0])
+                return tensor[[idx[1]]].to(self.device)
+            if all(isinstance(x, int) for x in idx):
+                to_return_list = []
+                for i in self.__slots__:
+                    to_return_list.append(getattr(self, i)[[idx]].to(self.device))
+                return to_return_list
+
+        raise ValueError(f'Invalid index {idx}')
diff --git a/pina/data/data_dataset.py b/pina/data/data_dataset.py
deleted file mode 100644
index 9dff2d7e..00000000
--- a/pina/data/data_dataset.py
+++ /dev/null
@@ -1,41 +0,0 @@
-from torch.utils.data import Dataset
-import torch
-from ..label_tensor import LabelTensor
-
-
-class DataPointDataset(Dataset):
-
-    def __init__(self, problem, device) -> None:
-        super().__init__()
-        input_list = []
-        output_list = []
-        self.condition_names = []
-
-        for name, condition in 
problem.conditions.items(): - if hasattr(condition, "output_points"): - input_list.append(problem.conditions[name].input_points) - output_list.append(problem.conditions[name].output_points) - self.condition_names.append(name) - - self.input_pts = LabelTensor.stack(input_list) - self.output_pts = LabelTensor.stack(output_list) - - if self.input_pts != []: - self.condition_indeces = torch.cat( - [ - torch.tensor([i] * len(input_list[i])) - for i in range(len(self.condition_names)) - ], - dim=0, - ) - else: # if there are no data points - self.condition_indeces = torch.tensor([]) - self.input_pts = torch.tensor([]) - self.output_pts = torch.tensor([]) - - self.input_pts = self.input_pts.to(device) - self.output_pts = self.output_pts.to(device) - self.condition_indeces = self.condition_indeces.to(device) - - def __len__(self): - return self.input_pts.shape[0] \ No newline at end of file diff --git a/pina/data/data_module.py b/pina/data/data_module.py new file mode 100644 index 00000000..e4e8a450 --- /dev/null +++ b/pina/data/data_module.py @@ -0,0 +1,172 @@ +""" +This module provide basic data management functionalities +""" + +import math +import torch +from lightning import LightningDataModule +from .sample_dataset import SamplePointDataset +from .supervised_dataset import SupervisedDataset +from .unsupervised_dataset import UnsupervisedDataset +from .pina_dataloader import PinaDataLoader +from .pina_subset import PinaSubset + + +class PinaDataModule(LightningDataModule): + """ + This class extend LightningDataModule, allowing proper creation and + management of different types of Datasets defined in PINA + """ + + def __init__(self, + problem, + device, + train_size=.7, + test_size=.2, + eval_size=.1, + batch_size=None, + shuffle=True, + datasets = None): + """ + Initialize the object, creating dataset based on input problem + :param AbstractProblem problem: PINA problem + :param device: Device used for training and testing + :param train_size: number/percentage of elements in train split + :param test_size: number/percentage of elements in test split + :param eval_size: number/percentage of elements in evaluation split + :param batch_size: batch size used for training + :param datasets: list of datasets objects + """ + super().__init__() + dataset_classes = [SupervisedDataset, UnsupervisedDataset, SamplePointDataset] + if datasets is None: + self.datasets = [DatasetClass(problem, device) for DatasetClass in dataset_classes] + else: + self.datasets = datasets + + self.split_length = [] + self.split_names = [] + if train_size > 0: + self.split_names.append('train') + self.split_length.append(train_size) + if test_size > 0: + self.split_length.append(test_size) + self.split_names.append('test') + if eval_size > 0: + self.split_length.append(eval_size) + self.split_names.append('eval') + + self.batch_size = batch_size + self.condition_names = None + self.splits = {k: {} for k in self.split_names} + self.shuffle = shuffle + + def setup(self, stage=None): + """ + Perform the splitting of the dataset + """ + self.extract_conditions() + if stage == 'fit' or stage is None: + for dataset in self.datasets: + if len(dataset) > 0: + splits = self.dataset_split(dataset, + self.split_length, + shuffle=self.shuffle) + for i in range(len(self.split_length)): + self.splits[ + self.split_names[i]][dataset.data_type] = splits[i] + elif stage == 'test': + raise NotImplementedError("Testing pipeline not implemented yet") + else: + raise ValueError("stage must be either 'fit' or 'test'") + + def 
extract_conditions(self):
+        """
+        Extract conditions from the datasets and make the condition
+        indices unique across datasets
+        """
+        # Extract number of conditions
+        n_conditions = 0
+        for dataset in self.datasets:
+            if n_conditions != 0:
+                dataset.condition_names = {
+                    key + n_conditions: value
+                    for key, value in dataset.condition_names.items()
+                }
+            n_conditions += len(dataset.condition_names)
+
+        self.condition_names = {
+            key: value
+            for dataset in self.datasets
+            for key, value in dataset.condition_names.items()
+        }
+
+    def train_dataloader(self):
+        """
+        Return the training dataloader for the dataset
+        :return: data loader
+        :rtype: PinaDataLoader
+        """
+        return PinaDataLoader(self.splits['train'], self.batch_size,
+                              self.condition_names)
+
+    def test_dataloader(self):
+        """
+        Return the testing dataloader for the dataset
+        :return: data loader
+        :rtype: PinaDataLoader
+        """
+        return PinaDataLoader(self.splits['test'], self.batch_size,
+                              self.condition_names)
+
+    def eval_dataloader(self):
+        """
+        Return the evaluation dataloader for the dataset
+        :return: data loader
+        :rtype: PinaDataLoader
+        """
+        return PinaDataLoader(self.splits['eval'], self.batch_size,
+                              self.condition_names)
+
+    @staticmethod
+    def dataset_split(dataset, lengths, seed=None, shuffle=True):
+        """
+        Perform the splitting of the dataset
+        :param dataset: dataset object to split
+        :param lengths: lengths (or fractions summing to 1) of the splits
+        :param seed: random seed
+        :param shuffle: shuffle dataset
+        :return: list of split datasets
+        :rtype: list[PinaSubset]
+        """
+        if sum(lengths) - 1 < 1e-3:
+            lengths = [
+                int(math.floor(len(dataset) * length)) for length in lengths
+            ]
+
+            remainder = len(dataset) - sum(lengths)
+            for i in range(remainder):
+                lengths[i % len(lengths)] += 1
+        elif sum(lengths) - 1 >= 1e-3:
+            raise ValueError(f"Sum of lengths is {sum(lengths)}, greater than 1")
+
+        if sum(lengths) != len(dataset):
+            raise ValueError("Sum of lengths is not equal to dataset length")
+
+        if shuffle:
+            if seed is not None:
+                generator = torch.Generator()
+                generator.manual_seed(seed)
+                indices = torch.randperm(sum(lengths), generator=generator).tolist()
+            else:
+                indices = torch.randperm(sum(lengths)).tolist()
+        else:
+            indices = torch.arange(sum(lengths)).tolist()
+        offsets = [
+            sum(lengths[:i]) if i > 0 else 0 for i in range(len(lengths))
+        ]
+        return [
+            PinaSubset(dataset, indices[offset:offset + length])
+            for offset, length in zip(offsets, lengths)
+        ]
diff --git a/pina/data/pina_batch.py b/pina/data/pina_batch.py
index cb1296ed..7e46a221 100644
--- a/pina/data/pina_batch.py
+++ b/pina/data/pina_batch.py
@@ -1,36 +1,33 @@
+"""
+Batch management module
+"""
+from .pina_subset import PinaSubset
 
 
 class Batch:
-    """
-    This class is used to create a dataset of sample points.
- """ + def __init__(self, dataset_dict, idx_dict): - def __init__(self, type_, idx, *args, **kwargs) -> None: - """ - """ - if type_ == "sample": - - if len(args) != 2: - raise RuntimeError - - input = args[0] - conditions = args[1] - - self.input = input[idx] - self.condition = conditions[idx] + for k, v in dataset_dict.items(): + setattr(self, k, v) - elif type_ == "data": + for k, v in idx_dict.items(): + setattr(self, k + '_idx', v) - if len(args) != 3: - raise RuntimeError - - input = args[0] - output = args[1] - conditions = args[2] - - self.input = input[idx] - self.output = output[idx] - self.condition = conditions[idx] - - else: - raise ValueError("Invalid number of arguments.") \ No newline at end of file + def __len__(self): + """ + Returns the number of elements in the batch + :return: number of elements in the batch + :rtype: int + """ + length = 0 + for dataset in dir(self): + attribute = getattr(self, dataset) + if isinstance(attribute, list): + length += len(getattr(self, dataset)) + return length + + def __getattr__(self, item): + if not item in dir(self): + raise AttributeError(f'Batch instance has no attribute {item}') + return PinaSubset(getattr(self, item).dataset, + getattr(self, item).indices[self.coordinates_dict[item]]) diff --git a/pina/data/pina_dataloader.py b/pina/data/pina_dataloader.py index 2c8967c5..d6284757 100644 --- a/pina/data/pina_dataloader.py +++ b/pina/data/pina_dataloader.py @@ -1,11 +1,11 @@ -import torch - -from .sample_dataset import SamplePointDataset -from .data_dataset import DataPointDataset +""" +This module is used to create an iterable object used during training +""" +import math from .pina_batch import Batch -class SamplePointLoader: +class PinaDataLoader: """ This class is used to create a dataloader to use during the training. @@ -14,198 +14,54 @@ class SamplePointLoader: :vartype condition_names: list[str] """ - def __init__( - self, sample_dataset, data_dataset, batch_size=None, shuffle=True - ) -> None: - """ - Constructor. - - :param SamplePointDataset sample_pts: The sample points dataset. - :param int batch_size: The batch size. If ``None``, the batch size is - set to the number of sample points. Default is ``None``. - :param bool shuffle: If ``True``, the sample points are shuffled. - Default is ``True``. - """ - if not isinstance(sample_dataset, SamplePointDataset): - raise TypeError( - f"Expected SamplePointDataset, got {type(sample_dataset)}" - ) - if not isinstance(data_dataset, DataPointDataset): - raise TypeError( - f"Expected DataPointDataset, got {type(data_dataset)}" - ) - - self.n_data_conditions = len(data_dataset.condition_names) - self.n_phys_conditions = len(sample_dataset.condition_names) - data_dataset.condition_indeces += self.n_phys_conditions - - self._prepare_sample_dataset(sample_dataset, batch_size, shuffle) - self._prepare_data_dataset(data_dataset, batch_size, shuffle) - - self.condition_names = ( - sample_dataset.condition_names + data_dataset.condition_names - ) - - self.batch_list = [] - for i in range(len(self.batch_sample_pts)): - self.batch_list.append(("sample", i)) - - for i in range(len(self.batch_input_pts)): - self.batch_list.append(("data", i)) - - if shuffle: - self.random_idx = torch.randperm(len(self.batch_list)) - else: - self.random_idx = torch.arange(len(self.batch_list)) - - self._prepare_batches() - - def _prepare_data_dataset(self, dataset, batch_size, shuffle): - """ - Prepare the dataset for data points. - - :param SamplePointDataset dataset: The dataset. 
- :param int batch_size: The batch size. - :param bool shuffle: If ``True``, the sample points are shuffled. - """ - self.sample_dataset = dataset - - if len(dataset) == 0: - self.batch_data_conditions = [] - self.batch_input_pts = [] - self.batch_output_pts = [] - return - - if batch_size is None: - batch_size = len(dataset) - batch_num = len(dataset) // batch_size - if len(dataset) % batch_size != 0: - batch_num += 1 - - output_labels = dataset.output_pts.labels - input_labels = dataset.input_pts.labels - self.tensor_conditions = dataset.condition_indeces - - if shuffle: - idx = torch.randperm(dataset.input_pts.shape[0]) - self.input_pts = dataset.input_pts[idx] - self.output_pts = dataset.output_pts[idx] - self.tensor_conditions = dataset.condition_indeces[idx] - - self.batch_input_pts = torch.tensor_split(dataset.input_pts, batch_num) - self.batch_output_pts = torch.tensor_split( - dataset.output_pts, batch_num - ) - #print(input_labels) - for i in range(len(self.batch_input_pts)): - self.batch_input_pts[i].labels = input_labels - self.batch_output_pts[i].labels = output_labels - - self.batch_data_conditions = torch.tensor_split( - self.tensor_conditions, batch_num - ) - - def _prepare_sample_dataset(self, dataset, batch_size, shuffle): + def __init__(self, dataset_dict, batch_size, condition_names) -> None: """ - Prepare the dataset for sample points. - - :param DataPointDataset dataset: The dataset. - :param int batch_size: The batch size. - :param bool shuffle: If ``True``, the sample points are shuffled. + Initialize local variables + :param dataset_dict: Dictionary of datasets + :type dataset_dict: dict + :param batch_size: Size of the batch + :type batch_size: int + :param condition_names: Names of the conditions + :type condition_names: list[str] """ + self.condition_names = condition_names + self.dataset_dict = dataset_dict + self._init_batches(batch_size) - self.sample_dataset = dataset - if len(dataset) == 0: - self.batch_sample_conditions = [] - self.batch_sample_pts = [] - return - - if batch_size is None: - batch_size = len(dataset) - - batch_num = len(dataset) // batch_size - if len(dataset) % batch_size != 0: - batch_num += 1 - - self.tensor_pts = dataset.pts - self.tensor_conditions = dataset.condition_indeces - - # if shuffle: - # idx = torch.randperm(self.tensor_pts.shape[0]) - # self.tensor_pts = self.tensor_pts[idx] - # self.tensor_conditions = self.tensor_conditions[idx] - - self.batch_sample_pts = torch.tensor_split(self.tensor_pts, batch_num) - for i in range(len(self.batch_sample_pts)): - self.batch_sample_pts[i].labels = dataset.pts.labels - - self.batch_sample_conditions = torch.tensor_split( - self.tensor_conditions, batch_num - ) - - def _prepare_batches(self): + def _init_batches(self, batch_size=None): """ - Prepare the batches. + Create batches according to the batch_size provided in input. 
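+
+        A sketch of the strategy implemented below: with ``n_batches =
+        ceil(n_elements / batch_size)``, every dataset in ``dataset_dict``
+        contributes a slice of ``floor(len(dataset) / (n_batches - 1))``
+        elements to each batch except the last, which takes the remaining
+        elements; a single batch therefore mixes all dataset types.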
""" self.batches = [] - for i in range(len(self.batch_list)): - type_, idx_ = self.batch_list[i] - - if type_ == "sample": - batch = Batch( - "sample", idx_, - self.batch_sample_pts, - self.batch_sample_conditions) + n_elements = sum([len(v) for v in self.dataset_dict.values()]) + if batch_size is None: + batch_size = n_elements + indexes_dict = {} + n_batches = int(math.ceil(n_elements / batch_size)) + for k, v in self.dataset_dict.items(): + if n_batches != 1: + indexes_dict[k] = math.floor(len(v) / (n_batches - 1)) else: - batch = Batch( - "data", idx_, - self.batch_input_pts, - self.batch_output_pts, - self.batch_data_conditions) - - self.batches.append(batch) + indexes_dict[k] = len(v) + for i in range(n_batches): + temp_dict = {} + for k, v in indexes_dict.items(): + if i != n_batches - 1: + temp_dict[k] = slice(i * v, (i + 1) * v) + else: + temp_dict[k] = slice(i * v, len(self.dataset_dict[k])) + self.batches.append(Batch(idx_dict=temp_dict, dataset_dict=self.dataset_dict)) def __iter__(self): """ - Return an iterator over the points. Any element of the iterator is a - dictionary with the following keys: - - ``pts``: The input sample points. It is a LabelTensor with the - shape ``(batch_size, input_dimension)``. - - ``output``: The output sample points. This key is present only - if data conditions are present. It is a LabelTensor with the - shape ``(batch_size, output_dimension)``. - - ``condition``: The integer condition indeces. It is a tensor - with the shape ``(batch_size, )`` of type ``torch.int64`` and - indicates for any ``pts`` the corresponding problem condition. - - :return: An iterator over the points. - :rtype: iter + Makes dataloader object iterable """ - # for i in self.random_idx: - for i in self.random_idx: - yield self.batches[i] - - # for i in range(len(self.batch_list)): - # type_, idx_ = self.batch_list[i] - - # if type_ == "sample": - # d = { - # "pts": self.batch_sample_pts[idx_].requires_grad_(True), - # "condition": self.batch_sample_conditions[idx_], - # } - # else: - # d = { - # "pts": self.batch_input_pts[idx_].requires_grad_(True), - # "output": self.batch_output_pts[idx_], - # "condition": self.batch_data_conditions[idx_], - # } - # yield d + yield from self.batches def __len__(self): """ Return the number of batches. - :return: The number of batches. :rtype: int """ - return len(self.batch_list) + return len(self.batches) diff --git a/pina/data/pina_subset.py b/pina/data/pina_subset.py new file mode 100644 index 00000000..41571f92 --- /dev/null +++ b/pina/data/pina_subset.py @@ -0,0 +1,21 @@ +class PinaSubset: + """ + TODO + """ + __slots__ = ['dataset', 'indices'] + + def __init__(self, dataset, indices): + """ + TODO + """ + self.dataset = dataset + self.indices = indices + + def __len__(self): + """ + TODO + """ + return len(self.indices) + + def __getattr__(self, name): + return self.dataset.__getattribute__(name) diff --git a/pina/data/sample_dataset.py b/pina/data/sample_dataset.py index 84af2920..ba8bd19a 100644 --- a/pina/data/sample_dataset.py +++ b/pina/data/sample_dataset.py @@ -1,43 +1,12 @@ -from torch.utils.data import Dataset -import torch +""" +Sample dataset module +""" +from .base_dataset import BaseDataset -from ..label_tensor import LabelTensor - - -class SamplePointDataset(Dataset): +class SamplePointDataset(BaseDataset): """ - This class is used to create a dataset of sample points. + This class extends the BaseDataset to handle physical datasets + composed of only input points. 
""" - - def __init__(self, problem, device) -> None: - """ - :param dict input_pts: The input points. - """ - super().__init__() - pts_list = [] - self.condition_names = [] - - for name, condition in problem.conditions.items(): - if not hasattr(condition, "output_points"): - pts_list.append(problem.input_pts[name]) - self.condition_names.append(name) - - self.pts = LabelTensor.stack(pts_list) - - if self.pts != []: - self.condition_indeces = torch.cat( - [ - torch.tensor([i] * len(pts_list[i])) - for i in range(len(self.condition_names)) - ], - dim=0, - ) - else: # if there are no sample points - self.condition_indeces = torch.tensor([]) - self.pts = torch.tensor([]) - - self.pts = self.pts.to(device) - self.condition_indeces = self.condition_indeces.to(device) - - def __len__(self): - return self.pts.shape[0] \ No newline at end of file + data_type = 'physics' + __slots__ = ['input_points'] diff --git a/pina/data/supervised_dataset.py b/pina/data/supervised_dataset.py new file mode 100644 index 00000000..2403e3d0 --- /dev/null +++ b/pina/data/supervised_dataset.py @@ -0,0 +1,12 @@ +""" +Supervised dataset module +""" +from .base_dataset import BaseDataset + + +class SupervisedDataset(BaseDataset): + """ + This class extends the BaseDataset to handle datasets that consist of input-output pairs. + """ + data_type = 'supervised' + __slots__ = ['input_points', 'output_points'] diff --git a/pina/data/unsupervised_dataset.py b/pina/data/unsupervised_dataset.py new file mode 100644 index 00000000..f4e8fb34 --- /dev/null +++ b/pina/data/unsupervised_dataset.py @@ -0,0 +1,13 @@ +""" +Unsupervised dataset module +""" +from .base_dataset import BaseDataset + + +class UnsupervisedDataset(BaseDataset): + """ + This class extend BaseDataset class to handle unsupervised dataset, + composed of input points and, optionally, conditional variables + """ + data_type = 'unsupervised' + __slots__ = ['input_points', 'conditional_variables'] diff --git a/pina/domain/cartesian.py b/pina/domain/cartesian.py index 6e9b81af..4986ea7e 100644 --- a/pina/domain/cartesian.py +++ b/pina/domain/cartesian.py @@ -33,7 +33,7 @@ def __init__(self, cartesian_dict): @property def sample_modes(self): return ["random", "grid", "lh", "chebyshev", "latin"] - + @property def variables(self): """Spatial variables. diff --git a/pina/domain/ellipsoid.py b/pina/domain/ellipsoid.py index b9185fa0..18e28d4b 100644 --- a/pina/domain/ellipsoid.py +++ b/pina/domain/ellipsoid.py @@ -55,7 +55,6 @@ def __init__(self, ellipsoid_dict, sample_surface=False): # perform operation only for not fixed variables (if any) if self.range_: - # convert dict vals to torch [dim, 2] matrix list_dict_vals = list(self.range_.values()) tmp = torch.tensor(list_dict_vals, dtype=torch.float) @@ -74,7 +73,7 @@ def __init__(self, ellipsoid_dict, sample_surface=False): @property def sample_modes(self): return ["random"] - + @property def variables(self): """Spatial variables. diff --git a/pina/domain/operation_interface.py b/pina/domain/operation_interface.py index a1efec91..0300f524 100644 --- a/pina/domain/operation_interface.py +++ b/pina/domain/operation_interface.py @@ -69,4 +69,4 @@ def _check_dimensions(self, geometries): if geometry.variables != geometries[0].variables: raise NotImplementedError( f"The geometries need to have same dimensions and labels." 
- ) \ No newline at end of file + ) diff --git a/pina/domain/simplex.py b/pina/domain/simplex.py index 96cc36c0..931f861a 100644 --- a/pina/domain/simplex.py +++ b/pina/domain/simplex.py @@ -77,7 +77,7 @@ def __init__(self, simplex_matrix, sample_surface=False): @property def sample_modes(self): return ["random"] - + @property def variables(self): return self._vertices_matrix.labels @@ -144,7 +144,7 @@ def is_inside(self, point, check_border=False): return all(torch.gt(lambdas, 0.0)) and all(torch.lt(lambdas, 1.0)) return all(torch.ge(lambdas, 0)) and ( - any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1)) + any(torch.eq(lambdas, 0)) or any(torch.eq(lambdas, 1)) ) def _sample_interior_randomly(self, n, variables): diff --git a/pina/domain/union_domain.py b/pina/domain/union_domain.py index a72115f5..0af8e1bd 100644 --- a/pina/domain/union_domain.py +++ b/pina/domain/union_domain.py @@ -37,13 +37,13 @@ def __init__(self, geometries): def sample_modes(self): self.sample_modes = list( set([geom.sample_modes for geom in self.geometries]) - ) - + ) + @property def variables(self): variables = [] for geom in self.geometries: - variables+=geom.variables + variables += geom.variables return list(set(variables)) def is_inside(self, point, check_border=False): diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 1df318ec..65655e9d 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -3,6 +3,7 @@ import torch from torch import Tensor + def issubset(a, b): """ Check if a is a subset of b. @@ -45,7 +46,7 @@ def labels(self): :return: labels of self :rtype: list """ - return self._labels[self.tensor.ndim-1]['dof'] + return self._labels[self.tensor.ndim - 1]['dof'] @property def full_labels(self): @@ -103,23 +104,23 @@ def extract(self, label_to_extract): raise ValueError('labels_to_extract must be str or list or dict') def _extract_from_list(self, labels_to_extract): - #Store locally all necessary obj/variables + # Store locally all necessary obj/variables ndim = self.tensor.ndim labels = self.full_labels tensor = self.tensor last_dim_label = self.labels - #Verify if all the labels in labels_to_extract are in last dimension + # Verify if all the labels in labels_to_extract are in last dimension if set(labels_to_extract).issubset(last_dim_label) is False: raise ValueError('Cannot extract a dof which is not in the original LabelTensor') - #Extract index to extract + # Extract index to extract idx_to_extract = [last_dim_label.index(i) for i in labels_to_extract] - #Perform extraction + # Perform extraction new_tensor = tensor[..., idx_to_extract] - #Manage labels + # Manage labels new_labels = copy(labels) last_dim_new_label = {ndim - 1: { @@ -186,7 +187,7 @@ def cat(tensors, dim=0): # Perform cat on tensors new_tensor = torch.cat(tensors, dim=dim) - #Update labels + # Update labels labels = tensors[0].full_labels labels.pop(dim) new_labels_cat_dim = new_labels_cat_dim if len(set(new_labels_cat_dim)) == len(new_labels_cat_dim) \ @@ -265,13 +266,13 @@ def update_labels_from_dict(self, labels): :raises ValueError: dof list contain duplicates or number of dof does not match with tensor shape """ tensor_shape = self.tensor.shape - #Check dimensionality + # Check dimensionality for k, v in labels.items(): if len(v['dof']) != len(set(v['dof'])): raise ValueError("dof must be unique") if len(v['dof']) != tensor_shape[k]: raise ValueError('Number of dof does not match with tensor dimension') - #Perform update + # Perform update self._labels.update(labels) def update_labels_from_list(self, 
labels): @@ -310,7 +311,7 @@ def append(self, tensor, mode='std'): if mode == 'std': # Call cat on last dimension new_label_tensor = LabelTensor.cat([self, tensor], dim=self.tensor.ndim - 1) - elif mode=='cross': + elif mode == 'cross': # Crete tensor and call cat on last dimension tensor1 = self tensor2 = tensor @@ -318,7 +319,7 @@ def append(self, tensor, mode='std'): n2 = tensor2.shape[0] tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels) tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels) - new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim-1) + new_label_tensor = LabelTensor.cat([tensor1, tensor2], dim=self.tensor.ndim - 1) else: raise ValueError('mode must be either "std" or "cross"') return new_label_tensor @@ -366,10 +367,10 @@ def __getitem__(self, index): if hasattr(self, "labels"): if isinstance(index[j], list): new_labels.update({j: {'dof': [new_labels[j]['dof'][i] for i in index[1]], - 'name':new_labels[j]['name']}}) + 'name': new_labels[j]['name']}}) else: new_labels.update({j: {'dof': new_labels[j]['dof'][index[j]], - 'name':new_labels[j]['name']}}) + 'name': new_labels[j]['name']}}) selected_lt.labels = new_labels else: @@ -382,12 +383,13 @@ def __getitem__(self, index): def sort_labels(self, dim=None): def argsort(lst): return sorted(range(len(lst)), key=lambda x: lst[x]) + if dim is None: - dim = self.tensor.ndim-1 + dim = self.tensor.ndim - 1 labels = self.full_labels[dim]['dof'] sorted_index = argsort(labels) indexer = [slice(None)] * self.tensor.ndim indexer[dim] = sorted_index new_labels = deepcopy(self.full_labels) new_labels[dim] = {'dof': sorted(labels), 'name': new_labels[dim]['name']} - return LabelTensor(self.tensor[indexer], new_labels) \ No newline at end of file + return LabelTensor(self.tensor[indexer], new_labels) diff --git a/pina/operators.py b/pina/operators.py index fa32f292..9e780ec8 100644 --- a/pina/operators.py +++ b/pina/operators.py @@ -211,7 +211,7 @@ def laplacian(output_, input_, components=None, d=None, method="std"): result[:, idx] = grad(grad_output, input_, d=di).flatten() to_append_tensors[idx] = grad(grad_output, input_, d=di) labels[idx] = f"dd{ci[0]}dd{di[0]}" - result = LabelTensor.cat(tensors=to_append_tensors, dim=output_.tensor.ndim-1) + result = LabelTensor.cat(tensors=to_append_tensors, dim=output_.tensor.ndim - 1) result.labels = labels return result diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index 9e841d65..543f823f 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -27,13 +27,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): """ def __init__( - self, - models, - problem, - optimizers, - optimizers_kwargs, - extra_features, - loss, + self, + models, + problem, + optimizers, + optimizers_kwargs, + extra_features, + loss, ): """ :param models: Multiple torch neural network models instances. @@ -177,7 +177,7 @@ def compute_residual(self, samples, equation): try: residual = equation.residual(samples, self.forward(samples)) except ( - TypeError + TypeError ): # this occurs when the function has three inputs, i.e. 
inverse problem residual = equation.residual( samples, self.forward(samples), self._params diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index a27e9364..8b3ddae7 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -10,168 +10,6 @@ import sys -# class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): -# """ -# Solver base class. This class inherits is a wrapper of -# LightningModule class, inheriting all the -# LightningModule methods. -# """ - -# def __init__( -# self, -# models, -# problem, -# optimizers, -# optimizers_kwargs, -# extra_features=None, -# ): -# """ -# :param models: A torch neural network model instance. -# :type models: torch.nn.Module -# :param problem: A problem definition instance. -# :type problem: AbstractProblem -# :param list(torch.optim.Optimizer) optimizer: A list of neural network optimizers to -# use. -# :param list(dict) optimizer_kwargs: A list of optimizer constructor keyword args. -# :param list(torch.nn.Module) extra_features: The additional input -# features to use as augmented input. If ``None`` no extra features -# are passed. If it is a list of :class:`torch.nn.Module`, the extra feature -# list is passed to all models. If it is a list of extra features' lists, -# each single list of extra feature is passed to a model. -# """ -# super().__init__() - -# # check consistency of the inputs -# check_consistency(models, torch.nn.Module) -# check_consistency(problem, AbstractProblem) -# check_consistency(optimizers, torch.optim.Optimizer, subclass=True) -# check_consistency(optimizers_kwargs, dict) - -# # put everything in a list if only one input -# if not isinstance(models, list): -# models = [models] -# if not isinstance(optimizers, list): -# optimizers = [optimizers] -# optimizers_kwargs = [optimizers_kwargs] - -# # number of models and optimizers -# len_model = len(models) -# len_optimizer = len(optimizers) -# len_optimizer_kwargs = len(optimizers_kwargs) - -# # check length consistency optimizers -# if len_model != len_optimizer: -# raise ValueError( -# "You must define one optimizer for each model." -# f"Got {len_model} models, and {len_optimizer}" -# " optimizers." -# ) - -# # check length consistency optimizers kwargs -# if len_optimizer_kwargs != len_optimizer: -# raise ValueError( -# "You must define one dictionary of keyword" -# " arguments for each optimizers." -# f"Got {len_optimizer} optimizers, and" -# f" {len_optimizer_kwargs} dicitionaries" -# ) - -# # extra features handling -# if (extra_features is None) or (len(extra_features) == 0): -# extra_features = [None] * len_model -# else: -# # if we only have a list of extra features -# if not isinstance(extra_features[0], (tuple, list)): -# extra_features = [extra_features] * len_model -# else: # if we have a list of list extra features -# if len(extra_features) != len_model: -# raise ValueError( -# "You passed a list of extrafeatures list with len" -# f"different of models len. Expected {len_model} " -# f"got {len(extra_features)}. If you want to use " -# "the same list of extra features for all models, " -# "just pass a list of extrafeatures and not a list " -# "of list of extra features." 
-# ) - -# # assigning model and optimizers -# self._pina_models = [] -# self._pina_optimizers = [] - -# for idx in range(len_model): -# model_ = Network( -# model=models[idx], -# input_variables=problem.input_variables, -# output_variables=problem.output_variables, -# extra_features=extra_features[idx], -# ) -# optim_ = optimizers[idx]( -# model_.parameters(), **optimizers_kwargs[idx] -# ) -# self._pina_models.append(model_) -# self._pina_optimizers.append(optim_) - -# # assigning problem -# self._pina_problem = problem - -# @abstractmethod -# def forward(self, *args, **kwargs): -# pass - -# @abstractmethod -# def training_step(self): -# pass - -# @abstractmethod -# def configure_optimizers(self): -# pass - -# @property -# def models(self): -# """ -# The torch model.""" -# return self._pina_models - -# @property -# def optimizers(self): -# """ -# The torch model.""" -# return self._pina_optimizers - -# @property -# def problem(self): -# """ -# The problem formulation.""" -# return self._pina_problem - -# def on_train_start(self): -# """ -# On training epoch start this function is call to do global checks for -# the different solvers. -# """ - -# # 1. Check the verison for dataloader -# dataloader = self.trainer.train_dataloader -# if sys.version_info < (3, 8): -# dataloader = dataloader.loaders -# self._dataloader = dataloader - -# return super().on_train_start() - - # @model.setter - # def model(self, new_model): - # """ - # Set the torch.""" - # check_consistency(new_model, nn.Module, 'torch model') - # self._model= new_model - - # @problem.setter - # def problem(self, problem): - # """ - # Set the problem formulation.""" - # check_consistency(problem, AbstractProblem, 'pina problem') - # self._problem = problem - class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): """ Solver base class. This class inherits is a wrapper of @@ -181,10 +19,12 @@ class SolverInterface(pytorch_lightning.LightningModule, metaclass=ABCMeta): def __init__( self, - model, + models, problem, - optimizer, - scheduler, + optimizers, + schedulers, + extra_features, + use_lt=True ): """ :param model: A torch neural network model instance. 
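
For orientation, the refactored interface below takes (possibly singleton)
lists of models, optimizers and schedulers. A minimal usage sketch through the
SupervisedSolver subclass defined later in this patch (assumptions: a
``problem`` and a ``model`` already exist, and ``TorchOptimizer`` /
``TorchScheduler`` wrap an optimizer or scheduler class together with its
keyword arguments, as their use in pina/solvers/supervised.py suggests):

    import torch
    from pina.optim import TorchOptimizer, TorchScheduler
    from pina.solvers import SupervisedSolver

    solver = SupervisedSolver(
        problem=problem,  # an AbstractProblem with 'supervised' conditions
        model=model,      # a torch.nn.Module; wrapped in Network when use_lt=True
        optimizer=TorchOptimizer(torch.optim.Adam, lr=0.001),
        scheduler=TorchScheduler(torch.optim.lr_scheduler.ConstantLR),
    )
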
@@ -197,22 +37,45 @@ def __init__(
         super().__init__()
 
         # check consistency of the inputs
-        check_consistency(model, torch.nn.Module)
         check_consistency(problem, AbstractProblem)
-        check_consistency(optimizer, Optimizer)
-        check_consistency(scheduler, Scheduler)
-
-        # put everything in a list if only one input
-        if not isinstance(model, list):
-            model = [model]
-        if not isinstance(scheduler, list):
-            scheduler = [scheduler]
-        if not isinstance(optimizer, list):
-            optimizer = [optimizer]
-
-        # number of models and optimizers
-        len_model = len(model)
-        len_optimizer = len(optimizer)
+        self._check_solver_consistency(problem)
+
+        # Check consistency of the models argument and encapsulate in a list
+        if not isinstance(models, list):
+            check_consistency(models, torch.nn.Module)
+            # put everything in a list if only one input
+            models = [models]
+        else:
+            for idx in range(len(models)):
+                # Check consistency
+                check_consistency(models[idx], torch.nn.Module)
+        len_model = len(models)
+
+        # If use_lt is True, wrap each model to extract the labelled input
+        if use_lt is True:
+            for idx in range(len(models)):
+                models[idx] = Network(
+                    model=models[idx],
+                    input_variables=problem.input_variables,
+                    output_variables=problem.output_variables,
+                    extra_features=extra_features,
+                )
+
+        # Check scheduler consistency and encapsulate in a list
+        if not isinstance(schedulers, list):
+            check_consistency(schedulers, Scheduler)
+            schedulers = [schedulers]
+        else:
+            for scheduler in schedulers:
+                check_consistency(scheduler, Scheduler)
+
+        # Check optimizer consistency and encapsulate in a list
+        if not isinstance(optimizers, list):
+            check_consistency(optimizers, Optimizer)
+            optimizers = [optimizers]
+        else:
+            for optimizer in optimizers:
+                check_consistency(optimizer, Optimizer)
+        len_optimizer = len(optimizers)
 
         # check length consistency optimizers
         if len_model != len_optimizer:
@@ -223,10 +86,12 @@ def __init__(
             )
 
         # extra features handling
+
+        self._pina_models = models
+        self._pina_optimizers = optimizers
+        self._pina_schedulers = schedulers
         self._pina_problem = problem
-        self._pina_model = model
-        self._pina_optimizer = optimizer
-        self._pina_scheduler = scheduler
+
 
     @abstractmethod
     def forward(self, *args, **kwargs):
         pass
@@ -244,13 +109,13 @@ def configure_optimizers(self):
     def models(self):
         """
         The torch models."""
-        return self._pina_model
+        return self._pina_models
 
     @property
     def optimizers(self):
         """
         The torch optimizers."""
-        return self._pina_optimizer
+        return self._pina_optimizers
 
     @property
     def problem(self):
@@ -272,16 +137,10 @@ def on_train_start(self):
 
         return super().on_train_start()
 
-    # @model.setter
-    # def model(self, new_model):
-    #     """
-    #     Set the torch."""
-    #     check_consistency(new_model, nn.Module, 'torch model')
-    #     self._model= new_model
-
-    # @problem.setter
-    # def problem(self, problem):
-    #     """
-    #     Set the problem formulation."""
-    #     check_consistency(problem, AbstractProblem, 'pina problem')
-    #     self._problem = problem
+    def _check_solver_consistency(self, problem):
+        """
+        Check that every condition of the problem is supported
+        by the solver.
+        """
+        for _, condition in problem.conditions.items():
+            condition_types = condition.condition_type
+            if not isinstance(condition_types, list):
+                condition_types = [condition_types]
+            if not set(self.accepted_condition_types).issubset(condition_types):
+                raise ValueError(
+                    f'{self.__name__} does not support conditions of type '
+                    f'{condition.condition_type}')
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index c44d5a1e..32f687ed 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -2,9 +2,7 @@
 
 import torch
 from torch.nn.modules.loss import _Loss
-
-
-from ..optim import Optimizer, Scheduler, TorchOptimizer, TorchScheduler
+from ..optim import TorchOptimizer, TorchScheduler
 from .solver import SolverInterface
 from ..label_tensor import LabelTensor
 from ..utils import check_consistency
@@ -39,14 +37,17 @@ class SupervisedSolver(SolverInterface):
     we are seeking to approximate multiple (discretised) functions given
     multiple (discretised) input functions.
     """
+    accepted_condition_types = ['supervised']
+    __name__ = 'SupervisedSolver'
 
     def __init__(
-        self,
-        problem,
-        model,
-        loss=None,
-        optimizer=None,
-        scheduler=None,
+            self,
+            problem,
+            model,
+            loss=None,
+            optimizer=None,
+            scheduler=None,
+            extra_features=None
     ):
         """
         :param AbstractProblem problem: The formulation of the problem.
@@ -57,11 +58,8 @@ def __init__(
             features to use as augmented input.
         :param torch.optim.Optimizer optimizer: The neural network optimizer to
             use; default is :class:`torch.optim.Adam`.
-        :param dict optimizer_kwargs: Optimizer constructor keyword args.
-        :param float lr: The learning rate; default is 0.001.
         :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
-        :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
@@ -74,18 +72,19 @@
             torch.optim.lr_scheduler.ConstantLR)
 
         super().__init__(
-            model=model,
+            models=model,
             problem=problem,
-            optimizer=optimizer,
-            scheduler=scheduler,
+            optimizers=optimizer,
+            schedulers=scheduler,
+            extra_features=extra_features
         )
 
         # check consistency
         check_consistency(loss, (LossInterface, _Loss), subclass=False)
         self._loss = loss
 
-        self._model = self._pina_model[0]
-        self._optimizer = self._pina_optimizer[0]
-        self._scheduler = self._pina_scheduler[0]
+        self._model = self._pina_models[0]
+        self._optimizer = self._pina_optimizers[0]
+        self._scheduler = self._pina_schedulers[0]
 
     def forward(self, x):
         """Forward pass implementation for the solver.
@@ -97,12 +96,7 @@ def forward(self, x):
 
         output = self._model(x)
 
-        output.labels = {
-            1: {
-                "name": "output",
-                "dof": self.problem.output_variables
-            }
-        }
+        output.labels = self.problem.output_variables
         return output
 
     def configure_optimizers(self):
@@ -128,16 +122,14 @@ def training_step(self, batch, batch_idx):
         :return: The sum of the loss functions.
        :rtype: LabelTensor
         """
-
-        condition_idx = batch.condition
+        condition_idx = batch.supervised.condition_indices
 
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
 
             condition_name = self._dataloader.condition_names[condition_id]
             condition = self.problem.conditions[condition_name]
-            pts = batch.input
-            out = batch.output
-
+            pts = batch.supervised.input_points
+            out = batch.supervised.output_points
             if condition_name not in self.problem.conditions:
                 raise RuntimeError("Something wrong happened.")
 
@@ -167,8 +159,8 @@ def loss_data(self, input_pts, output_pts):
         the network output against the true solution. This function
         should not be overridden unless done intentionally.
 
-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
+        :param LabelTensor input_pts: The input to the neural networks.
+        :param LabelTensor output_pts: The true solution to compare the
             network solution.
        :return: The residual loss averaged on the input coordinates
         :rtype: torch.Tensor
@@ -181,7 +173,7 @@
         Scheduler for training.
""" return self._scheduler - + @property def optimizer(self): """ diff --git a/pina/trainer.py b/pina/trainer.py index ba18f339..49c6a401 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -3,13 +3,13 @@ import torch import pytorch_lightning from .utils import check_consistency -from .data import SamplePointDataset, SamplePointLoader, DataPointDataset +from .data import PinaDataModule from .solvers.solver import SolverInterface class Trainer(pytorch_lightning.Trainer): - def __init__(self, solver, batch_size=None, **kwargs): + def __init__(self, solver, batch_size=None, train_size=.7, test_size=.2, eval_size=.1, **kwargs): """ PINA Trainer class for costumizing every aspect of training via flags. @@ -31,10 +31,11 @@ def __init__(self, solver, batch_size=None, **kwargs): check_consistency(solver, SolverInterface) if batch_size is not None: check_consistency(batch_size, int) - + self.train_size = train_size + self.test_size = test_size + self.eval_size = eval_size self.solver = solver self.batch_size = batch_size - self._create_loader() self._move_to_device() @@ -69,11 +70,12 @@ def _create_loader(self): raise RuntimeError("Parallel training is not supported yet.") device = devices[0] - dataset_phys = SamplePointDataset(self.solver.problem, device) - dataset_data = DataPointDataset(self.solver.problem, device) - self._loader = SamplePointLoader( - dataset_phys, dataset_data, batch_size=self.batch_size, shuffle=True - ) + + data_module = PinaDataModule(problem=self.solver.problem, device=device, + train_size=self.train_size, test_size=self.test_size, + eval_size=self.eval_size) + data_module.setup() + self._loader = data_module.train_dataloader() def train(self, **kwargs): """ @@ -89,3 +91,7 @@ def solver(self): Returning trainer solver. """ return self._solver + + @solver.setter + def solver(self, solver): + self._solver = solver diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 40f21922..264f794b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,44 +1,45 @@ +import math import torch -import pytest - -from pina.data.dataset import SamplePointDataset, SamplePointLoader, DataPointDataset +from pina.data import SamplePointDataset, SupervisedDataset, PinaDataModule, UnsupervisedDataset, unsupervised_dataset +from pina.data import PinaDataLoader from pina import LabelTensor, Condition from pina.equation import Equation from pina.domain import CartesianDomain from pina.problem import SpatialProblem -from pina.model import FeedForward from pina.operators import laplacian from pina.equation.equation_factory import FixedValue def laplace_equation(input_, output_): - force_term = (torch.sin(input_.extract(['x'])*torch.pi) * - torch.sin(input_.extract(['y'])*torch.pi)) + force_term = (torch.sin(input_.extract(['x']) * torch.pi) * + torch.sin(input_.extract(['y']) * torch.pi)) delta_u = laplacian(output_.extract(['u']), input_) return delta_u - force_term + my_laplace = Equation(laplace_equation) in_ = LabelTensor(torch.tensor([[0., 1.]]), ['x', 'y']) out_ = LabelTensor(torch.tensor([[0.]]), ['u']) in2_ = LabelTensor(torch.rand(60, 2), ['x', 'y']) out2_ = LabelTensor(torch.rand(60, 1), ['u']) + class Poisson(SpatialProblem): output_variables = ['u'] spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]}) conditions = { 'gamma1': Condition( - location=CartesianDomain({'x': [0, 1], 'y': 1}), + domain=CartesianDomain({'x': [0, 1], 'y': 1}), equation=FixedValue(0.0)), 'gamma2': Condition( - location=CartesianDomain({'x': [0, 1], 'y': 0}), + 
domain=CartesianDomain({'x': [0, 1], 'y': 0}), equation=FixedValue(0.0)), 'gamma3': Condition( - location=CartesianDomain({'x': 1, 'y': [0, 1]}), + domain=CartesianDomain({'x': 1, 'y': [0, 1]}), equation=FixedValue(0.0)), 'gamma4': Condition( - location=CartesianDomain({'x': 0, 'y': [0, 1]}), + domain=CartesianDomain({'x': 0, 'y': [0, 1]}), equation=FixedValue(0.0)), 'D': Condition( input_points=LabelTensor(torch.rand(size=(100, 2)), ['x', 'y']), @@ -48,75 +49,114 @@ class Poisson(SpatialProblem): output_points=out_), 'data2': Condition( input_points=in2_, - output_points=out2_) + output_points=out2_), + 'unsupervised': Condition( + input_points=LabelTensor(torch.rand(size=(45, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(45, 1)), ['alpha']), + ), + 'unsupervised2': Condition( + input_points=LabelTensor(torch.rand(size=(90, 2)), ['x', 'y']), + conditional_variables=LabelTensor(torch.ones(size=(90, 1)), ['alpha']), + ) } + boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4'] poisson = Poisson() poisson.discretise_domain(10, 'grid', locations=boundaries) + def test_sample(): sample_dataset = SamplePointDataset(poisson, device='cpu') assert len(sample_dataset) == 140 - assert sample_dataset.pts.shape == (140, 2) - assert sample_dataset.pts.labels == ['x', 'y'] - assert sample_dataset.condition_indeces.dtype == torch.int64 - assert sample_dataset.condition_indeces.max() == torch.tensor(4) - assert sample_dataset.condition_indeces.min() == torch.tensor(0) + assert sample_dataset.input_points.shape == (140, 2) + assert sample_dataset.input_points.labels == ['x', 'y'] + assert sample_dataset.condition_indices.dtype == torch.uint8 + assert sample_dataset.condition_indices.max() == torch.tensor(4) + assert sample_dataset.condition_indices.min() == torch.tensor(0) + def test_data(): - dataset = DataPointDataset(poisson, device='cpu') + dataset = SupervisedDataset(poisson, device='cpu') assert len(dataset) == 61 - assert dataset.input_pts.shape == (61, 2) - assert dataset.input_pts.labels == ['x', 'y'] - assert dataset.output_pts.shape == (61, 1 ) - assert dataset.output_pts.labels == ['u'] - assert dataset.condition_indeces.dtype == torch.int64 - assert dataset.condition_indeces.max() == torch.tensor(1) - assert dataset.condition_indeces.min() == torch.tensor(0) - -def test_loader(): - sample_dataset = SamplePointDataset(poisson, device='cpu') - data_dataset = DataPointDataset(poisson, device='cpu') - loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10) - + assert dataset['input_points'].shape == (61, 2) + assert dataset.input_points.shape == (61, 2) + assert dataset['input_points'].labels == ['x', 'y'] + assert dataset.input_points.labels == ['x', 'y'] + assert dataset['input_points', 3:].shape == (58, 2) + assert dataset[3:][1].labels == ['u'] + assert dataset.output_points.shape == (61, 1) + assert dataset.output_points.labels == ['u'] + assert dataset.condition_indices.dtype == torch.uint8 + assert dataset.condition_indices.max() == torch.tensor(1) + assert dataset.condition_indices.min() == torch.tensor(0) + + +def test_unsupervised(): + dataset = UnsupervisedDataset(poisson, device='cpu') + assert len(dataset) == 135 + assert dataset.input_points.shape == (135, 2) + assert dataset.input_points.labels == ['x', 'y'] + assert dataset.input_points[3:].shape == (132, 2) + + assert dataset.conditional_variables.shape == (135, 1) + assert dataset.conditional_variables.labels == ['alpha'] + assert dataset.condition_indices.dtype == torch.uint8 + assert 
+    assert dataset.condition_indices.max() == torch.tensor(1)
+    assert dataset.condition_indices.min() == torch.tensor(0)
+
+
+def test_data_module():
+    data_module = PinaDataModule(poisson, device='cpu')
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert isinstance(loader, PinaDataLoader)
+
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10, shuffle=False)
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert len(loader) == 24
+    for i in loader:
+        assert len(i) <= 10
+    len_ref = sum([math.ceil(len(dataset) * 0.7) for dataset in data_module.datasets])
+    len_real = sum([len(dataset) for dataset in data_module.splits['train'].values()])
+    assert len_ref == len_real
+
+    supervised_dataset = SupervisedDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10, shuffle=False, datasets=[supervised_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
     for batch in loader:
-        assert len(batch) in [2, 3]
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
-
-    loader2 = SamplePointLoader(sample_dataset, data_dataset, batch_size=None)
-    assert len(list(loader2)) == 2
-
-def test_loader2():
-    poisson2 = Poisson()
-    del poisson.conditions['data2']
-    del poisson2.conditions['data']
-    poisson2.discretise_domain(10, 'grid', locations=boundaries)
-    sample_dataset = SamplePointDataset(poisson, device='cpu')
-    data_dataset = DataPointDataset(poisson, device='cpu')
-    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+        assert len(batch) <= 10
 
+    physics_dataset = SamplePointDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10, shuffle=False, datasets=[physics_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
     for batch in loader:
-        assert len(batch) == 2 # only phys condtions
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
-
-def test_loader3():
-    poisson2 = Poisson()
-    del poisson.conditions['gamma1']
-    del poisson.conditions['gamma2']
-    del poisson.conditions['gamma3']
-    del poisson.conditions['gamma4']
-    del poisson.conditions['D']
-    sample_dataset = SamplePointDataset(poisson, device='cpu')
-    data_dataset = DataPointDataset(poisson, device='cpu')
-    loader = SamplePointLoader(sample_dataset, data_dataset, batch_size=10)
+        assert len(batch) <= 10
 
+    unsupervised_dataset = UnsupervisedDataset(poisson, device='cpu')
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10, shuffle=False, datasets=[unsupervised_dataset])
+    data_module.setup()
+    loader = data_module.train_dataloader()
     for batch in loader:
-        assert len(batch) == 2 # only phys condtions
-        assert batch['pts'].shape[0] <= 10
-        assert batch['pts'].requires_grad == True
-        assert batch['pts'].labels == ['x', 'y']
+        assert len(batch) <= 10
+
+
+def test_loader():
+    data_module = PinaDataModule(poisson, device='cpu', batch_size=10)
+    data_module.setup()
+    loader = data_module.train_dataloader()
+    assert isinstance(loader, PinaDataLoader)
+    assert len(loader) == 24
+    for i in loader:
+        assert len(i) <= 10
+        assert i.supervised.input_points.labels == ['x', 'y']
+        assert i.physics.input_points.labels == ['x', 'y']
+        assert i.unsupervised.input_points.labels == ['x', 'y']
+        assert i.supervised.input_points.requires_grad == True
+        assert i.physics.input_points.requires_grad == True
+        assert i.unsupervised.input_points.requires_grad == True
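The tests above fix the intended dataloading flow: a PinaDataModule is set up from a problem, splits each dataset according to the given fractions, and serves PinaDataLoader batches that group points by condition type. Sketched end to end (reusing the `poisson` problem defined in this test module; the `physics`/`supervised`/`unsupervised` batch attributes follow the assertions in `test_loader`):

    # Sketch of the flow exercised by test_data_module and test_loader.
    data_module = PinaDataModule(poisson, device='cpu', batch_size=10)
    data_module.setup()                      # builds the train/test/eval splits
    loader = data_module.train_dataloader()  # a PinaDataLoader

    for batch in loader:                     # len(batch) <= batch_size
        phys_pts = batch.physics.input_points     # LabelTensor, labels ['x', 'y']
        sup_in = batch.supervised.input_points
        sup_out = batch.supervised.output_points  # LabelTensor, labels ['u']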
diff --git a/tests/test_solvers/test_supervised_solver.py b/tests/test_solvers/test_supervised_solver.py
index 912480bb..8ceadcd9 100644
--- a/tests/test_solvers/test_supervised_solver.py
+++ b/tests/test_solvers/test_supervised_solver.py
@@ -1,51 +1,28 @@
 import torch
-
-from pina.problem import AbstractProblem
+import pytest
+from pina.problem import AbstractProblem, SpatialProblem
 from pina import Condition, LabelTensor
 from pina.solvers import SupervisedSolver
-from pina.trainer import Trainer
 from pina.model import FeedForward
-from pina.loss import LpLoss
-from pina.solvers import GraphSupervisedSolver
+from pina.equation.equation import Equation
+from pina.equation.equation_factory import FixedValue
+from pina.operators import laplacian
+from pina.domain import CartesianDomain
+from pina.trainer import Trainer
+
+in_ = LabelTensor(torch.tensor([[0., 1.]]), ['u_0', 'u_1'])
+out_ = LabelTensor(torch.tensor([[0.]]), ['u'])
+
 
 class NeuralOperatorProblem(AbstractProblem):
     input_variables = ['u_0', 'u_1']
     output_variables = ['u']
-    domains = {
-        'pts': LabelTensor(
-            torch.rand(100, 2),
-            labels={1: {'name': 'space', 'dof': ['u_0', 'u_1']}}
-        )
-    }
-    conditions = {
-        'data' : Condition(
-            domain='pts',
-            output_points=LabelTensor(
-                torch.rand(100, 1),
-                labels={1: {'name': 'output', 'dof': ['u']}}
-            )
-        )
-    }
-class NeuralOperatorProblemGraph(AbstractProblem):
-    input_variables = ['x', 'y', 'u_0', 'u_1']
-    output_variables = ['u']
-    domains = {
-        'pts': LabelTensor(
-            torch.rand(100, 4),
-            labels={1: {'name': 'space', 'dof': ['x', 'y', 'u_0', 'u_1']}}
-        )
-    }
     conditions = {
-        'data' : Condition(
-            domain='pts',
-            output_points=LabelTensor(
-                torch.rand(100, 1),
-                labels={1: {'name': 'output', 'dof': ['u']}}
-            )
-        )
+        'data': Condition(input_points=in_, output_points=out_),
     }
 
+
 class myFeature(torch.nn.Module):
     """
     Feature: sin(x)
@@ -61,117 +38,106 @@ def forward(self, x):
 
 
 problem = NeuralOperatorProblem()
-problem_graph = NeuralOperatorProblemGraph()
-# make the problem + extra feats
 extra_feats = [myFeature()]
-model = FeedForward(len(problem.input_variables),
-                    len(problem.output_variables))
+model = FeedForward(len(problem.input_variables), len(problem.output_variables))
 model_extra_feats = FeedForward(
-    len(problem.input_variables) + 1,
-    len(problem.output_variables))
+    len(problem.input_variables) + 1, len(problem.output_variables))
 
 
 def test_constructor():
     SupervisedSolver(problem=problem, model=model)
 
 
-# def test_constructor_extra_feats():
-#     SupervisedSolver(problem=problem, model=model_extra_feats, extra_features=extra_feats)
-
-'''
-class AutoSolver(SupervisedSolver):
-
-    def forward(self, input):
-        from pina.graph import Graph
-        print(Graph)
-        print(input)
-        if not isinstance(input, Graph):
-            input = Graph.build('radius', nodes_coordinates=input, nodes_data=torch.rand(input.shape), radius=0.2)
-        print(input)
-        print(input.data.edge_index)
-        print(input.data)
-        g = self._model(input.data, edge_index=input.data.edge_index)
-        g.labels = {1: {'name': 'output', 'dof': ['u']}}
-        return g
-        du_dt_new = LabelTensor(self.model(graph).reshape(-1,1), labels = ['du'])
-
-        return du_dt_new
-'''
-
-class GraphModel(torch.nn.Module):
-    def __init__(self, in_channels, out_channels):
-        from torch_geometric.nn import GCNConv, NNConv
-        super().__init__()
-        self.conv1 = GCNConv(in_channels, 16)
-        self.conv2 = GCNConv(16, out_channels)
-
-    def forward(self, data, edge_index):
-        print(data)
-        x = data.x
-        print(x)
-        x = self.conv1(x, edge_index)
-        x = x.relu()
-        x = self.conv2(x, edge_index)
-        return x
-
-def test_graph():
-    solver = GraphSupervisedSolver(problem=problem_graph, model=GraphModel(2, 1), loss=LpLoss(),
-                                   nodes_coordinates=['x', 'y'], nodes_data=['u_0', 'u_1'])
-    trainer = Trainer(solver=solver, max_epochs=30, accelerator='cpu', batch_size=20)
-    trainer.train()
+
+
+def laplace_equation(input_, output_):
+    force_term = (torch.sin(input_.extract(['x']) * torch.pi) *
+                  torch.sin(input_.extract(['y']) * torch.pi))
+    delta_u = laplacian(output_.extract(['u']), input_)
+    return delta_u - force_term
+
+
+my_laplace = Equation(laplace_equation)
+
+
+class Poisson(SpatialProblem):
+    output_variables = ['u']
+    spatial_domain = CartesianDomain({'x': [0, 1], 'y': [0, 1]})
+
+    conditions = {
+        'gamma1': Condition(
+            domain=CartesianDomain({'x': [0, 1], 'y': 1}),
+            equation=FixedValue(0.0)),
+        'gamma2': Condition(
+            domain=CartesianDomain({'x': [0, 1], 'y': 0}),
+            equation=FixedValue(0.0)),
+        'gamma3': Condition(
+            domain=CartesianDomain({'x': 1, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'gamma4': Condition(
+            domain=CartesianDomain({'x': 0, 'y': [0, 1]}),
+            equation=FixedValue(0.0)),
+        'D': Condition(
+            domain=CartesianDomain({'x': [0, 1], 'y': [0, 1]}),
+            equation=my_laplace),
+        'data': Condition(input_points=in_, output_points=out_)
+    }
+
+    def poisson_sol(self, pts):
+        return -(torch.sin(pts.extract(['x']) * torch.pi) *
+                 torch.sin(pts.extract(['y']) * torch.pi)) / (2 * torch.pi ** 2)
+
+    truth_solution = poisson_sol
+
+
+def test_wrong_constructor():
+    poisson_problem = Poisson()
+    with pytest.raises(ValueError):
+        SupervisedSolver(problem=poisson_problem, model=model)
 
 
 def test_train_cpu():
-    solver = SupervisedSolver(problem = problem, model=model, loss=LpLoss())
-    trainer = Trainer(solver=solver, max_epochs=300, accelerator='cpu', batch_size=20)
+    solver = SupervisedSolver(problem=problem, model=model)
+    trainer = Trainer(solver=solver,
+                      max_epochs=200,
+                      accelerator='cpu',
+                      batch_size=5,
+                      train_size=1.,
+                      test_size=0.,
+                      eval_size=0.)
+    trainer.train()
 
 
-# def test_train_restore():
-#     tmpdir = "tests/tmp_restore"
-#     solver = SupervisedSolver(problem=problem,
-#                               model=model,
-#                               extra_features=None,
-#                               loss=LpLoss())
-#     trainer = Trainer(solver=solver,
-#                       max_epochs=5,
-#                       accelerator='cpu',
-#                       default_root_dir=tmpdir)
-#     trainer.train()
-#     ntrainer = Trainer(solver=solver, max_epochs=15, accelerator='cpu')
-#     t = ntrainer.train(
-#         ckpt_path=f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=4-step=5.ckpt')
-#     import shutil
-#     shutil.rmtree(tmpdir)
-
-
-# def test_train_load():
-#     tmpdir = "tests/tmp_load"
-#     solver = SupervisedSolver(problem=problem,
-#                               model=model,
-#                               extra_features=None,
-#                               loss=LpLoss())
-#     trainer = Trainer(solver=solver,
-#                       max_epochs=15,
-#                       accelerator='cpu',
-#                       default_root_dir=tmpdir)
-#     trainer.train()
-#     new_solver = SupervisedSolver.load_from_checkpoint(
-#         f'{tmpdir}/lightning_logs/version_0/checkpoints/epoch=14-step=15.ckpt',
-#         problem = problem, model=model)
-#     test_pts = LabelTensor(torch.rand(20, 2), problem.input_variables)
-#     assert new_solver.forward(test_pts).shape == (20, 1)
-#     assert new_solver.forward(test_pts).shape == solver.forward(test_pts).shape
-#     torch.testing.assert_close(
-#         new_solver.forward(test_pts),
-#         solver.forward(test_pts))
-#     import shutil
-#     shutil.rmtree(tmpdir)
-
-# def test_train_extra_feats_cpu():
-#     pinn = SupervisedSolver(problem=problem,
-#                             model=model_extra_feats,
-#                             extra_features=extra_feats)
-#     trainer = Trainer(solver=pinn, max_epochs=5, accelerator='cpu')
-#     trainer.train()
-test_graph()
\ No newline at end of file
+def test_extra_features_constructor():
+    SupervisedSolver(problem=problem,
+                     model=model_extra_feats,
+                     extra_features=extra_feats)
+
+
+def test_extra_features_train_cpu():
+    solver = SupervisedSolver(problem=problem,
+                              model=model_extra_feats,
+                              extra_features=extra_feats)
+    trainer = Trainer(solver=solver,
+                      max_epochs=200,
+                      accelerator='cpu',
+                      batch_size=5)
+    trainer.train()
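For reference, the happy path these tests pin down, condensed into one sketch (reusing `problem`, `model` and `in_` from the test module above; the epoch count is arbitrary):

    # Sketch: purely data-driven training with the reworked SupervisedSolver.
    solver = SupervisedSolver(problem=problem, model=model)
    trainer = Trainer(solver=solver,
                      max_epochs=10,
                      accelerator='cpu',
                      batch_size=5,
                      train_size=1., test_size=0., eval_size=0.)
    trainer.train()
    pred = solver.forward(in_)   # LabelTensor in -> LabelTensor out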