Skip to content

Commit

Permalink
Revert "Fix bugs (#385)"
Browse files Browse the repository at this point in the history
This reverts commit 69cd0ed.
  • Loading branch information
dario-coscia authored Nov 27, 2024
1 parent 69cd0ed commit ea57718
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor


full_labels = False
full_labels = True
MATH_FUNCTIONS = {torch.sin, torch.cos}

class LabelTensor(torch.Tensor):
Expand Down
1 change: 1 addition & 0 deletions pina/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM
from .graph import GraphSupervisedSolver
5 changes: 4 additions & 1 deletion pina/solvers/pinns/basepinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler

torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
Expand All @@ -24,7 +26,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""

accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]
def __init__(
self,
models,
Expand Down
2 changes: 1 addition & 1 deletion pina/solvers/pinns/pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


from .basepinn import PINNInterface
from ...problem import InverseProblem
from pina.problem import InverseProblem


class PINN(PINNInterface):
Expand Down
10 changes: 4 additions & 6 deletions pina/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,16 @@ def on_train_start(self):
return super().on_train_start()

def _check_solver_consistency(self, problem):
pass
#TODO : Implement this method for the conditions
'''
"""
TODO
"""
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')
'''

@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object
Expand Down
2 changes: 2 additions & 0 deletions pina/solvers/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition


class SupervisedSolver(SolverInterface):
Expand Down Expand Up @@ -37,6 +38,7 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver'

def __init__(self,
Expand Down
1 change: 1 addition & 0 deletions pina/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Trainer module. """
import warnings
import torch
import lightning
from .utils import check_consistency
Expand Down

0 comments on commit ea57718

Please sign in to comment.