Skip to content

Commit

Permalink
Fix bugs (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoOlivo authored Nov 27, 2024
1 parent 8c218c4 commit 69cd0ed
Show file tree
Hide file tree
Showing 7 changed files with 9 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pina/label_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import Tensor


full_labels = True
full_labels = False
MATH_FUNCTIONS = {torch.sin, torch.cos}

class LabelTensor(torch.Tensor):
Expand Down
1 change: 0 additions & 1 deletion pina/solvers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@
from .supervised import SupervisedSolver
from .rom import ReducedOrderModelSolver
from .garom import GAROM
from .graph import GraphSupervisedSolver
5 changes: 1 addition & 4 deletions pina/solvers/pinns/basepinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
from abc import ABCMeta, abstractmethod
import torch
from torch.nn.modules.loss import _Loss
from ...condition import InputOutputPointsCondition
from ...solvers.solver import SolverInterface
from ...utils import check_consistency
from ...loss.loss_interface import LossInterface
from ...problem import InverseProblem
from ...condition import DomainEquationCondition
from ...optim import TorchOptimizer, TorchScheduler

torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
Expand All @@ -26,8 +24,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
to the user to choose which problem the implemented solver inheriting from
this class is suitable for.
"""
accepted_condition_types = [DomainEquationCondition.condition_type[0],
InputOutputPointsCondition.condition_type[0]]

def __init__(
self,
models,
Expand Down
2 changes: 1 addition & 1 deletion pina/solvers/pinns/pinn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


from .basepinn import PINNInterface
from pina.problem import InverseProblem
from ...problem import InverseProblem


class PINN(PINNInterface):
Expand Down
10 changes: 6 additions & 4 deletions pina/solvers/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,18 @@ def on_train_start(self):
return super().on_train_start()

def _check_solver_consistency(self, problem):
"""
TODO
"""
pass
#TODO : Implement this method for the conditions
'''
for _, condition in problem.conditions.items():
if not set(condition.condition_type).issubset(
set(self.accepted_condition_types)):
raise ValueError(
f'{self.__name__} dose not support condition '
f'{condition.condition_type}')

'''
@staticmethod
def get_batch_size(batch):
# Assuming batch is your custom Batch object
Expand Down
2 changes: 0 additions & 2 deletions pina/solvers/supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ..label_tensor import LabelTensor
from ..utils import check_consistency
from ..loss.loss_interface import LossInterface
from ..condition import InputOutputPointsCondition


class SupervisedSolver(SolverInterface):
Expand Down Expand Up @@ -38,7 +37,6 @@ class SupervisedSolver(SolverInterface):
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
accepted_condition_types = [InputOutputPointsCondition.condition_type[0]]
__name__ = 'SupervisedSolver'

def __init__(self,
Expand Down
1 change: 0 additions & 1 deletion pina/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
""" Trainer module. """
import warnings
import torch
import lightning
from .utils import check_consistency
Expand Down

0 comments on commit 69cd0ed

Please sign in to comment.