From ea5771836148c5e9a430742698e2d94075647b8a Mon Sep 17 00:00:00 2001 From: Dario Coscia <93731561+dario-coscia@users.noreply.github.com> Date: Wed, 27 Nov 2024 19:58:17 +0100 Subject: [PATCH] Revert "Fix bugs (#385)" This reverts commit 69cd0ed8cda91c92dab6551a0c6dfd94d199cee7. --- pina/label_tensor.py | 2 +- pina/solvers/__init__.py | 1 + pina/solvers/pinns/basepinn.py | 5 ++++- pina/solvers/pinns/pinn.py | 2 +- pina/solvers/solver.py | 10 ++++------ pina/solvers/supervised.py | 2 ++ pina/trainer.py | 1 + 7 files changed, 14 insertions(+), 9 deletions(-) diff --git a/pina/label_tensor.py b/pina/label_tensor.py index 631c5253..a3cf5d23 100644 --- a/pina/label_tensor.py +++ b/pina/label_tensor.py @@ -4,7 +4,7 @@ from torch import Tensor -full_labels = False +full_labels = True MATH_FUNCTIONS = {torch.sin, torch.cos} class LabelTensor(torch.Tensor): diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py index 6b755661..c6f53c78 100644 --- a/pina/solvers/__init__.py +++ b/pina/solvers/__init__.py @@ -16,3 +16,4 @@ from .supervised import SupervisedSolver from .rom import ReducedOrderModelSolver from .garom import GAROM +from .graph import GraphSupervisedSolver diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index dab21248..588d7314 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -3,10 +3,12 @@ from abc import ABCMeta, abstractmethod import torch from torch.nn.modules.loss import _Loss +from ...condition import InputOutputPointsCondition from ...solvers.solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface from ...problem import InverseProblem +from ...condition import DomainEquationCondition from ...optim import TorchOptimizer, TorchScheduler torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 @@ -24,7 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ - + accepted_condition_types = [DomainEquationCondition.condition_type[0], + InputOutputPointsCondition.condition_type[0]] def __init__( self, models, diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py index d1ab21d7..08882020 100644 --- a/pina/solvers/pinns/pinn.py +++ b/pina/solvers/pinns/pinn.py @@ -11,7 +11,7 @@ from .basepinn import PINNInterface -from ...problem import InverseProblem +from pina.problem import InverseProblem class PINN(PINNInterface): diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 8052b4b8..3a8f400c 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -134,18 +134,16 @@ def on_train_start(self): return super().on_train_start() def _check_solver_consistency(self, problem): - pass - #TODO : Implement this method for the conditions - ''' - - + """ + TODO + """ for _, condition in problem.conditions.items(): if not set(condition.condition_type).issubset( set(self.accepted_condition_types)): raise ValueError( f'{self.__name__} dose not support condition ' f'{condition.condition_type}') - ''' + @staticmethod def get_batch_size(batch): # Assuming batch is your custom Batch object diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index bce4b31a..947ab3b2 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -7,6 +7,7 @@ from ..label_tensor import LabelTensor from ..utils import check_consistency from ..loss.loss_interface import LossInterface +from ..condition import InputOutputPointsCondition class SupervisedSolver(SolverInterface): @@ -37,6 +38,7 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ + accepted_condition_types = [InputOutputPointsCondition.condition_type[0]] __name__ = 'SupervisedSolver' def __init__(self, diff --git a/pina/trainer.py b/pina/trainer.py index f8bccd8c..a7c5c351 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,4 +1,5 @@ """ Trainer module. """ +import warnings import torch import lightning from .utils import check_consistency