Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Version 0.2 #359

Merged
merged 14 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions pina/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
__all__ = [
"PINN",
"Trainer",
"LabelTensor",
"Plotter",
"Condition",
"SamplePointDataset",
"SamplePointLoader",
"Trainer", "LabelTensor", "Plotter", "Condition", "SamplePointDataset",
"PinaDataModule", "PinaDataLoader", 'TorchOptimizer', 'Graph'
]

from .meta import *
Expand All @@ -15,4 +10,8 @@
from .plotter import Plotter
from .condition.condition import Condition
from .data import SamplePointDataset
from .data import SamplePointLoader
from .data import PinaDataModule
from .data import PinaDataLoader
from .optim import TorchOptimizer
from .optim import TorchScheduler
from .graph import Graph
114 changes: 114 additions & 0 deletions pina/collector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from .utils import check_consistency, merge_tensors


class Collector:

def __init__(self, problem):
# creating a hook between collector and problem
self.problem = problem

# this variable is used to store the data in the form:
# {'[condition_name]' :
# {'input_points' : Tensor,
# '[equation/output_points/conditional_variables]': Tensor}
# }
# those variables are used for the dataloading
self._data_collections = {name: {} for name in self.problem.conditions}
self.conditions_name = {
i: name
for i, name in enumerate(self.problem.conditions)
}

# variables used to check that all conditions are sampled
self._is_conditions_ready = {
name: False
for name in self.problem.conditions
}
self.full = False

@property
def full(self):
return all(self._is_conditions_ready.values())

@full.setter
def full(self, value):
check_consistency(value, bool)
self._full = value

@property
def data_collections(self):
return self._data_collections

@property
def problem(self):
return self._problem

@problem.setter
def problem(self, value):
self._problem = value

def store_fixed_data(self):
# loop over all conditions
for condition_name, condition in self.problem.conditions.items():
# if the condition is not ready and domain is not attribute
# of condition, we get and store the data
if (not self._is_conditions_ready[condition_name]) and (not hasattr(
condition, "domain")):
# get data
keys = condition.__slots__
values = [getattr(condition, name) for name in keys]
self.data_collections[condition_name] = dict(zip(keys, values))
# condition now is ready
self._is_conditions_ready[condition_name] = True

def store_sample_domains(self, n, mode, variables, sample_locations):
# loop over all locations
for loc in sample_locations:
# get condition
condition = self.problem.conditions[loc]
keys = ["input_points", "equation"]
# if the condition is not ready, we get and store the data
if (not self._is_conditions_ready[loc]):
# if it is the first time we sample
if not self.data_collections[loc]:
already_sampled = []
# if we have sampled the condition but not all variables
else:
already_sampled = [
self.data_collections[loc]['input_points']
]
# if the condition is ready but we want to sample again
else:
self._is_conditions_ready[loc] = False
already_sampled = []

# get the samples
samples = [
condition.domain.sample(n=n, mode=mode, variables=variables)
] + already_sampled
pts = merge_tensors(samples)
if (set(pts.labels).issubset(sorted(self.problem.input_variables))):
pts = pts.sort_labels()
if sorted(pts.labels) == sorted(self.problem.input_variables):
self._is_conditions_ready[loc] = True
values = [pts, condition.equation]
self.data_collections[loc] = dict(zip(keys, values))
else:
raise RuntimeError(
'Try to sample variables which are not in problem defined '
'in the problem')

def add_points(self, new_points_dict):
"""
Add input points to a sampled condition

:param new_points_dict: Dictonary of input points (condition_name:
LabelTensor)
:raises RuntimeError: if at least one condition is not already sampled
"""
for k, v in new_points_dict.items():
if not self._is_conditions_ready[k]:
raise RuntimeError(
'Cannot add points on a non sampled condition')
self.data_collections[k]['input_points'] = self.data_collections[k][
'input_points'].vstack(v)
10 changes: 6 additions & 4 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
__all__ = [
'Condition',
'ConditionInterface',
'DomainOutputCondition',
'DomainEquationCondition'
'DomainEquationCondition',
'InputPointsEquationCondition',
'InputOutputPointsCondition',
]

from .condition_interface import ConditionInterface
from .domain_output_condition import DomainOutputCondition
from .domain_equation_condition import DomainEquationCondition
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
112 changes: 35 additions & 77 deletions pina/condition/condition.py
Original file line number Diff line number Diff line change
@@ -1,107 +1,65 @@
""" Condition module. """

from ..label_tensor import LabelTensor
from ..domain import DomainInterface
from ..equation.equation import Equation

from . import DomainOutputCondition, DomainEquationCondition


def dummy(a):
"""Dummy function for testing purposes."""
return None
from .domain_equation_condition import DomainEquationCondition
from .input_equation_condition import InputPointsEquationCondition
from .input_output_condition import InputOutputPointsCondition
from .data_condition import DataConditionInterface


class Condition:
"""
The class ``Condition`` is used to represent the constraints (physical
equations, boundary conditions, etc.) that should be satisfied in the
problem at hand. Condition objects are used to formulate the PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in three ways:
problem at hand. Condition objects are used to formulate the
PINA :obj:`pina.problem.abstract_problem.AbstractProblem` object.
Conditions can be specified in four ways:

1. By specifying the input and output points of the condition; in such a
case, the model is trained to produce the output points given the input
points.
points. Those points can either be torch.Tensor, LabelTensors, Graph

2. By specifying the location and the equation of the condition; in such
a case, the model is trained to minimize the equation residual by
evaluating it at some samples of the location.

3. By specifying the input points and the equation of the condition; in
such a case, the model is trained to minimize the equation residual by
evaluating it at the passed input points.
evaluating it at the passed input points. The input points must be
a LabelTensor.

4. By specifying only the data matrix; in such a case the model is
trained with an unsupervised costum loss and uses the data in training.
Additionaly conditioning variables can be passed, whenever the model
has extra conditioning variable it depends on.

Example::

>>> example_domain = Span({'x': [0, 1], 'y': [0, 1]})
>>> def example_dirichlet(input_, output_):
>>> value = 0.0
>>> return output_.extract(['u']) - value
>>> example_input_pts = LabelTensor(
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
>>>
>>> Condition(
>>> input_points=example_input_pts,
>>> output_points=example_output_pts)
>>> Condition(
>>> location=example_domain,
>>> equation=example_dirichlet)
>>> Condition(
>>> input_points=example_input_pts,
>>> equation=example_dirichlet)
>>> TODO

"""

# def _dictvalue_isinstance(self, dict_, key_, class_):
# """Check if the value of a dictionary corresponding to `key` is an instance of `class_`."""
# if key_ not in dict_.keys():
# return True

# return isinstance(dict_[key_], class_)

# def __init__(self, *args, **kwargs):
# """
# Constructor for the `Condition` class.
# """
# self.data_weight = kwargs.pop("data_weight", 1.0)

# if len(args) != 0:
# raise ValueError(
# f"Condition takes only the following keyword arguments: {Condition.__slots__}."
# )
__slots__ = list(
set(InputOutputPointsCondition.__slots__ +
InputPointsEquationCondition.__slots__ +
DomainEquationCondition.__slots__ +
DataConditionInterface.__slots__))

def __new__(cls, *args, **kwargs):

if sorted(kwargs.keys()) == sorted(["input_points", "output_points"]):
return DomainOutputCondition(
domain=kwargs["input_points"],
output_points=kwargs["output_points"]
)
elif sorted(kwargs.keys()) == sorted(["domain", "output_points"]):
return DomainOutputCondition(**kwargs)
elif sorted(kwargs.keys()) == sorted(["domain", "equation"]):
if len(args) != 0:
raise ValueError("Condition takes only the following keyword "
f"arguments: {Condition.__slots__}.")

sorted_keys = sorted(kwargs.keys())
if sorted_keys == sorted(InputOutputPointsCondition.__slots__):
return InputOutputPointsCondition(**kwargs)
elif sorted_keys == sorted(InputPointsEquationCondition.__slots__):
return InputPointsEquationCondition(**kwargs)
elif sorted_keys == sorted(DomainEquationCondition.__slots__):
return DomainEquationCondition(**kwargs)
elif sorted_keys == sorted(DataConditionInterface.__slots__):
return DataConditionInterface(**kwargs)
elif sorted_keys == DataConditionInterface.__slots__[0]:
return DataConditionInterface(**kwargs)
else:
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
'''
if (
sorted(kwargs.keys()) != sorted(["input_points", "output_points"])
and sorted(kwargs.keys()) != sorted(["location", "equation"])
and sorted(kwargs.keys()) != sorted(["input_points", "equation"])
):
raise ValueError(f"Invalid keyword arguments {kwargs.keys()}.")
# TODO: remove, not used anymore
if not self._dictvalue_isinstance(kwargs, "input_points", LabelTensor):
raise TypeError("`input_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "output_points", LabelTensor):
raise TypeError("`output_points` must be a torch.Tensor.")
if not self._dictvalue_isinstance(kwargs, "location", Location):
raise TypeError("`location` must be a Location.")
if not self._dictvalue_isinstance(kwargs, "equation", Equation):
raise TypeError("`equation` must be a Equation.")

for key, value in kwargs.items():
setattr(self, key, value)
'''
38 changes: 25 additions & 13 deletions pina/condition/condition_interface.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,33 @@

from abc import ABCMeta, abstractmethod
from abc import ABCMeta


class ConditionInterface(metaclass=ABCMeta):

def __init__(self) -> None:
condition_types = ['physics', 'supervised', 'unsupervised']

def __init__(self, *args, **kwargs):
self._condition_type = None
self._problem = None

@abstractmethod
def residual(self, model):
"""
Compute the residual of the condition.
@property
def problem(self):
return self._problem

@problem.setter
def problem(self, value):
self._problem = value

:param model: The model to evaluate the condition.
:return: The residual of the condition.
"""
pass
@property
def condition_type(self):
return self._condition_type

def set_problem(self, problem):
self._problem = problem
@condition_type.setter
def condition_type(self, values):
if not isinstance(values, (list, tuple)):
values = [values]
for value in values:
if value not in ConditionInterface.condition_types:
raise ValueError(
'Unavailable type of condition, expected one of'
f' {ConditionInterface.condition_types}.')
self._condition_type = values
33 changes: 33 additions & 0 deletions pina/condition/data_condition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch

from . import ConditionInterface
from ..label_tensor import LabelTensor
from ..graph import Graph
from ..utils import check_consistency


class DataConditionInterface(ConditionInterface):
"""
Condition for data. This condition must be used every
time a Unsupervised Loss is needed in the Solver. The conditionalvariable
can be passed as extra-input when the model learns a conditional
distribution
"""

__slots__ = ["input_points", "conditional_variables"]

def __init__(self, input_points, conditional_variables=None):
"""
TODO
"""
super().__init__()
self.input_points = input_points
self.conditional_variables = conditional_variables
self._condition_type = 'unsupervised'

def __setattr__(self, key, value):
if (key == 'input_points') or (key == 'conditional_variables'):
check_consistency(value, (LabelTensor, Graph, torch.Tensor))
DataConditionInterface.__dict__[key].__set__(self, value)
elif key in ('_problem', '_condition_type'):
super().__setattr__(key, value)
Loading
Loading