From e0e9364d667cab3af12009cc9a01c4fe0ddba519 Mon Sep 17 00:00:00 2001 From: FilippoOlivo Date: Thu, 28 Nov 2024 13:19:33 +0100 Subject: [PATCH] Add check conditions-solver compatibility --- pina/solvers/pinns/basepinn.py | 4 ++++ pina/solvers/solver.py | 17 +++-------------- pina/solvers/supervised.py | 4 +++- 3 files changed, 10 insertions(+), 15 deletions(-) diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py index dab21248..c52b2027 100644 --- a/pina/solvers/pinns/basepinn.py +++ b/pina/solvers/pinns/basepinn.py @@ -3,6 +3,8 @@ from abc import ABCMeta, abstractmethod import torch from torch.nn.modules.loss import _Loss +from ...condition import InputOutputPointsCondition, \ + InputPointsEquationCondition, DomainEquationCondition from ...solvers.solver import SolverInterface from ...utils import check_consistency from ...loss.loss_interface import LossInterface @@ -24,6 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): to the user to choose which problem the implemented solver inheriting from this class is suitable for. """ + accepted_conditions_types = (InputOutputPointsCondition, + InputPointsEquationCondition, DomainEquationCondition) def __init__( self, diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py index 8052b4b8..10f62aa9 100644 --- a/pina/solvers/solver.py +++ b/pina/solvers/solver.py @@ -9,8 +9,6 @@ import torch import sys - - class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver base class. This class inherits is a wrapper of @@ -134,18 +132,9 @@ def on_train_start(self): return super().on_train_start() def _check_solver_consistency(self, problem): - pass - #TODO : Implement this method for the conditions - ''' - - - for _, condition in problem.conditions.items(): - if not set(condition.condition_type).issubset( - set(self.accepted_condition_types)): - raise ValueError( - f'{self.__name__} dose not support condition ' - f'{condition.condition_type}') - ''' + for condition in problem.conditions.values(): + check_consistency(condition, self.accepted_conditions_types) + @staticmethod def get_batch_size(batch): # Assuming batch is your custom Batch object diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py index bce4b31a..ab3207ac 100644 --- a/pina/solvers/supervised.py +++ b/pina/solvers/supervised.py @@ -2,6 +2,8 @@ import torch from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.nn.modules.loss import _Loss + +from ..condition import InputOutputPointsCondition from ..optim import TorchOptimizer, TorchScheduler from .solver import SolverInterface from ..label_tensor import LabelTensor @@ -37,7 +39,7 @@ class SupervisedSolver(SolverInterface): we are seeking to approximate multiple (discretised) functions given multiple (discretised) input functions. """ - __name__ = 'SupervisedSolver' + accepted_conditions_types = InputOutputPointsCondition def __init__(self, problem,