From 9463ae4b15c6a633c3a803bfa0012272a581ceea Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 10 May 2024 14:08:01 +0200
Subject: [PATCH] :art: Format Python code with psf/black (#297)

Co-authored-by: dario-coscia
---
 pina/solvers/__init__.py               |   3 +-
 pina/solvers/pinns/basepinn.py         |  33 +++----
 pina/solvers/pinns/causalpinn.py       |  34 +++----
 pina/solvers/pinns/competitive_pinn.py |  39 +++++----
 pina/solvers/pinns/gpinn.py            |  31 +++----
 pina/solvers/pinns/pinn.py             |   7 +-
 pina/solvers/pinns/sapinn.py           | 117 ++++++++++++-------------
 pina/solvers/rom.py                    |  51 ++++++-----
 pina/solvers/solver.py                 |   4 +-
 pina/solvers/supervised.py             |   4 +-
 pina/trainer.py                        |   6 +-
 11 files changed, 169 insertions(+), 160 deletions(-)

diff --git a/pina/solvers/__init__.py b/pina/solvers/__init__.py
index 2751e481..6b755661 100644
--- a/pina/solvers/__init__.py
+++ b/pina/solvers/__init__.py
@@ -9,11 +9,10 @@
     "SupervisedSolver",
     "ReducedOrderModelSolver",
     "GAROM",
-    ]
+]
 
 from .solver import SolverInterface
 from .pinns import *
 from .supervised import SupervisedSolver
 from .rom import ReducedOrderModelSolver
 from .garom import GAROM
-
diff --git a/pina/solvers/pinns/basepinn.py b/pina/solvers/pinns/basepinn.py
index 726cdf92..53d4d3a9 100644
--- a/pina/solvers/pinns/basepinn.py
+++ b/pina/solvers/pinns/basepinn.py
@@ -12,11 +12,12 @@
 torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
 
+
 class PINNInterface(SolverInterface, metaclass=ABCMeta):
     """
     Base PINN solver class. This class implements the Solver Interface
     for Physics Informed Neural Network solvers.
-    
+
     This class can be used to define PINNs with multiple ``optimizers``,
     and/or ``models``. By default it takes
@@ -72,7 +73,7 @@ def __init__(
             self._clamp_params = self._clamp_inverse_problem_params
         else:
             self._params = None
-            self._clamp_params = lambda : None
+            self._clamp_params = lambda: None
 
         # variable used internally to store residual losses at each epoch
         # this variable save the residual at each iteration (not weighted)
@@ -107,7 +108,7 @@ def training_step(self, batch, _):
             condition = self.problem.conditions[condition_name]
             pts = batch["pts"]
             # condition name is logged (if logs enabled)
-            self.__logged_metric = condition_name 
+            self.__logged_metric = condition_name
 
             if len(batch) == 2:
                 samples = pts[condition_idx == condition_id]
@@ -160,7 +161,7 @@ def loss_phys(self, samples, equation):
         :rtype: LabelTensor
         """
         pass
-    
+
     def compute_residual(self, samples, equation):
         """
         Compute the residual for Physics Informed learning. This function
@@ -182,7 +183,7 @@ def compute_residual(self, samples, equation):
                 samples, self.forward(samples), self._params
             )
         return residual
-    
+
     def store_log(self, loss_value):
         """
         Stores the loss value in the logger. This function should be
@@ -195,13 +196,13 @@ def store_log(self, loss_value):
         :param torch.Tensor loss_value: The value of the loss.
""" self.log( - self.__logged_metric+'_loss', - loss_value, - prog_bar=True, - logger=True, - on_epoch=True, - on_step=False, - ) + self.__logged_metric + "_loss", + loss_value, + prog_bar=True, + logger=True, + on_epoch=True, + on_step=False, + ) self.__logged_res_losses.append(loss_value) def on_train_epoch_end(self): @@ -211,10 +212,10 @@ def on_train_epoch_end(self): """ if self.__logged_res_losses: # storing mean loss - self.__logged_metric = 'mean' + self.__logged_metric = "mean" self.store_log( - sum(self.__logged_res_losses)/len(self.__logged_res_losses) - ) + sum(self.__logged_res_losses) / len(self.__logged_res_losses) + ) # free the logged losses self.__logged_res_losses = [] return super().on_train_epoch_end() @@ -244,4 +245,4 @@ def current_condition_name(self): :meth:`loss_phys` to extract the condition at which the loss is computed. """ - return self.__logged_metric \ No newline at end of file + return self.__logged_metric diff --git a/pina/solvers/pinns/causalpinn.py b/pina/solvers/pinns/causalpinn.py index fea0fe47..476e4c55 100644 --- a/pina/solvers/pinns/causalpinn.py +++ b/pina/solvers/pinns/causalpinn.py @@ -97,25 +97,27 @@ def __init__( :param dict scheduler_kwargs: LR scheduler constructor keyword args. :param int | float eps: The exponential decay parameter. Note that this value is kept fixed during the training, but can be changed by means - of a callback, e.g. for annealing. + of a callback, e.g. for annealing. """ super().__init__( - problem=problem, - model=model, - extra_features=extra_features, - loss=loss, - optimizer=optimizer, - optimizer_kwargs=optimizer_kwargs, - scheduler=scheduler, - scheduler_kwargs=scheduler_kwargs, + problem=problem, + model=model, + extra_features=extra_features, + loss=loss, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + scheduler=scheduler, + scheduler_kwargs=scheduler_kwargs, ) # checking consistency - check_consistency(eps, (int,float)) + check_consistency(eps, (int, float)) self._eps = eps if not isinstance(self.problem, TimeDependentProblem): - raise ValueError('Casual PINN works only for problems' - 'inheritig from TimeDependentProblem.') + raise ValueError( + "Casual PINN works only for problems" + "inheritig from TimeDependentProblem." + ) def loss_phys(self, samples, equation): """ @@ -144,14 +146,14 @@ def loss_phys(self, samples, equation): ) time_loss.append(loss_val) # store results - self.store_log(loss_value=float(sum(time_loss)/len(time_loss))) + self.store_log(loss_value=float(sum(time_loss) / len(time_loss))) # concatenate residuals time_loss = torch.stack(time_loss) # compute weights (without the gradient storing) with torch.no_grad(): weights = self._compute_weights(time_loss) return (weights * time_loss).mean() - + @property def eps(self): """ @@ -205,8 +207,8 @@ def _split_tensor_into_chunks(self, tensor): _, idx_split = time_tensor.unique(return_counts=True) # splitting chunks = torch.split(tensor, tuple(idx_split)) - return chunks, labels # return chunks - + return chunks, labels # return chunks + def _compute_weights(self, loss): """ Computes the weights for the physics loss based on the cumulative loss. 
diff --git a/pina/solvers/pinns/competitive_pinn.py b/pina/solvers/pinns/competitive_pinn.py
index 6404c0bb..5e011a47 100644
--- a/pina/solvers/pinns/competitive_pinn.py
+++ b/pina/solvers/pinns/competitive_pinn.py
@@ -117,7 +117,7 @@ def __init__(
                 optimizer_discriminator_kwargs,
             ],
             extra_features=None,  # CompetitivePINN doesn't take extra features
-            loss=loss
+            loss=loss,
         )
 
         # set automatic optimization for GANs
@@ -131,9 +131,7 @@ def __init__(
 
         # assign schedulers
         self._schedulers = [
-            scheduler_model(
-                self.optimizers[0], **scheduler_model_kwargs
-            ),
+            scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
             scheduler_discriminator(
                 self.optimizers[1], **scheduler_discriminator_kwargs
             ),
@@ -141,7 +139,7 @@ def __init__(
 
         self._model = self.models[0]
         self._discriminator = self.models[1]
-    
+
     def forward(self, x):
         r"""
         Forward pass implementation for the PINN solver. It returns the function
@@ -195,8 +193,11 @@ def loss_data(self, input_tensor, output_tensor):
         :rtype: torch.Tensor
         """
         self.optimizer_model.zero_grad()
-        loss_val = super().loss_data(
-            input_tensor, output_tensor).as_subclass(torch.Tensor)
+        loss_val = (
+            super()
+            .loss_data(input_tensor, output_tensor)
+            .as_subclass(torch.Tensor)
+        )
         loss_val.backward()
         self.optimizer_model.step()
         return loss_val
@@ -221,7 +222,7 @@ def configure_optimizers(self):
         )
         return self.optimizers, self._schedulers
 
-    def on_train_batch_end(self,outputs, batch, batch_idx):
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         """
         This method is called at the end of each training batch, and ovverides
         the PytorchLightining implementation for logging the checkpoints.
@@ -235,7 +236,9 @@ def on_train_batch_end(self,outputs, batch, batch_idx):
         :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
-        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
+            1
+        )
         return super().on_train_batch_end(outputs, batch, batch_idx)
 
@@ -252,13 +255,14 @@ def _train_discriminator(self, samples, equation, discriminator_bets):
         self.optimizer_discriminator.zero_grad()
         # compute residual, we detach because the weights of the generator
         # model are fixed
-        residual = self.compute_residual(samples=samples,
-                                         equation=equation).detach()
+        residual = self.compute_residual(
+            samples=samples, equation=equation
+        ).detach()
         # compute competitive residual, the minus is because we maximise
         competitive_residual = residual * discriminator_bets
-        loss_val = - self.loss(
+        loss_val = -self.loss(
             torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual
+            competitive_residual,
         ).as_subclass(torch.Tensor)
         # backprop
         self.manual_backward(loss_val)
@@ -283,16 +287,13 @@ def _train_model(self, samples, equation, discriminator_bets):
         residual = self.compute_residual(samples=samples, equation=equation)
         # store logging
         with torch.no_grad():
-            loss_residual = self.loss(
-                torch.zeros_like(residual),
-                residual
-            )
+            loss_residual = self.loss(torch.zeros_like(residual), residual)
         # compute competitive residual, discriminator_bets are detached becase
         # we optimize only the generator model
         competitive_residual = residual * discriminator_bets.detach()
         loss_val = self.loss(
             torch.zeros_like(competitive_residual, requires_grad=True),
-            competitive_residual
+            competitive_residual,
         ).as_subclass(torch.Tensor)
         # backprop
         self.manual_backward(loss_val)
@@ -357,4 +358,4 @@ def scheduler_discriminator(self):
         :return: The scheduler for the discriminator.
         :rtype: torch.optim.lr_scheduler._LRScheduler
         """
-        return self._schedulers[1]
\ No newline at end of file
+        return self._schedulers[1]
diff --git a/pina/solvers/pinns/gpinn.py b/pina/solvers/pinns/gpinn.py
index 6eca1eac..5f259ca2 100644
--- a/pina/solvers/pinns/gpinn.py
+++ b/pina/solvers/pinns/gpinn.py
@@ -90,22 +90,23 @@ def __init__(
         :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
         super().__init__(
-                problem=problem,
-                model=model,
-                extra_features=extra_features,
-                loss=loss,
-                optimizer=optimizer,
-                optimizer_kwargs=optimizer_kwargs,
-                scheduler=scheduler,
-                scheduler_kwargs=scheduler_kwargs,
+            problem=problem,
+            model=model,
+            extra_features=extra_features,
+            loss=loss,
+            optimizer=optimizer,
+            optimizer_kwargs=optimizer_kwargs,
+            scheduler=scheduler,
+            scheduler_kwargs=scheduler_kwargs,
         )
 
         if not isinstance(self.problem, SpatialProblem):
-            raise ValueError('Gradient PINN computes the gradient of the '
-                             'PINN loss with respect to the spatial '
-                             'coordinates, thus the PINA problem must be '
-                             'a SpatialProblem.')
+            raise ValueError(
+                "Gradient PINN computes the gradient of the "
+                "PINN loss with respect to the spatial "
+                "coordinates, thus the PINA problem must be "
+                "a SpatialProblem."
+            )
 
-
     def loss_phys(self, samples, equation):
         """
         Computes the physics loss for the GPINN solver based on given
@@ -126,9 +127,9 @@ def loss_phys(self, samples, equation):
         self.store_log(loss_value=float(loss_value))
         # gradient PINN loss
         loss_value = loss_value.reshape(-1, 1)
-        loss_value.labels = ['__LOSS']
+        loss_value.labels = ["__LOSS"]
         loss_grad = grad(loss_value, samples, d=self.problem.spatial_variables)
         g_loss_phys = self.loss(
             torch.zeros_like(loss_grad, requires_grad=True), loss_grad
         )
-        return loss_value + g_loss_phys
\ No newline at end of file
+        return loss_value + g_loss_phys
diff --git a/pina/solvers/pinns/pinn.py b/pina/solvers/pinns/pinn.py
index 318283a3..15f90818 100644
--- a/pina/solvers/pinns/pinn.py
+++ b/pina/solvers/pinns/pinn.py
@@ -87,7 +87,7 @@ def __init__(
             optimizers=[optimizer],
             optimizers_kwargs=[optimizer_kwargs],
             extra_features=extra_features,
-            loss=loss
+            loss=loss,
         )
 
         # check consistency
@@ -131,7 +131,6 @@ def loss_phys(self, samples, equation):
         self.store_log(loss_value=float(loss_value))
         return loss_value
 
-
     def configure_optimizers(self):
         """
         Optimizer configuration for the PINN
@@ -153,7 +152,6 @@ def configure_optimizers(self):
         )
         return self.optimizers, [self.scheduler]
 
-
     @property
     def scheduler(self):
         """
@@ -161,10 +159,9 @@ def scheduler(self):
         """
         return self._scheduler
 
-
     @property
     def neural_net(self):
         """
         Neural network for the PINN training.
         """
-        return self._neural_net
\ No newline at end of file
+        return self._neural_net
diff --git a/pina/solvers/pinns/sapinn.py b/pina/solvers/pinns/sapinn.py
index 8de2d14c..751e21ef 100644
--- a/pina/solvers/pinns/sapinn.py
+++ b/pina/solvers/pinns/sapinn.py
@@ -14,6 +14,7 @@
 
 from torch.optim.lr_scheduler import ConstantLR
 
+
 class Weights(torch.nn.Module):
     """
     This class aims to implements the mask model for
@@ -27,11 +28,9 @@ def __init__(self, func):
         """
         super().__init__()
         check_consistency(func, torch.nn.Module)
-        self.sa_weights = torch.nn.Parameter(
-            torch.Tensor()
-        )
+        self.sa_weights = torch.nn.Parameter(torch.Tensor())
         self.func = func
-    
+
     def forward(self):
         """
         Forward pass implementation for the mask module.
@@ -43,6 +42,7 @@ def forward(self):
         """
         return self.func(self.sa_weights)
 
+
 class SAPINN(PINNInterface):
     r"""
     Self Adaptive Physics Informed Neural Network (SAPINN) solver class.
@@ -106,22 +106,22 @@ class SAPINN(PINNInterface):
         DOI: `10.1016/
         j.jcp.2022.111722 `_.
     """
-    
-    def __init__(
-            self,
-            problem,
-            model,
-            weights_function=torch.nn.Sigmoid(),
-            extra_features=None,
-            loss=torch.nn.MSELoss(),
-            optimizer_model=torch.optim.Adam,
-            optimizer_model_kwargs={"lr" : 0.001},
-            optimizer_weights=torch.optim.Adam,
-            optimizer_weights_kwargs={"lr" : 0.001},
-            scheduler_model=ConstantLR,
-            scheduler_model_kwargs={"factor" : 1, "total_iters" : 0},
-            scheduler_weights=ConstantLR,
-            scheduler_weights_kwargs={"factor" : 1, "total_iters" : 0}
+
+    def __init__(
+        self,
+        problem,
+        model,
+        weights_function=torch.nn.Sigmoid(),
+        extra_features=None,
+        loss=torch.nn.MSELoss(),
+        optimizer_model=torch.optim.Adam,
+        optimizer_model_kwargs={"lr": 0.001},
+        optimizer_weights=torch.optim.Adam,
+        optimizer_weights_kwargs={"lr": 0.001},
+        scheduler_model=ConstantLR,
+        scheduler_model_kwargs={"factor": 1, "total_iters": 0},
+        scheduler_weights=ConstantLR,
+        scheduler_weights_kwargs={"factor": 1, "total_iters": 0},
     ):
         """
         :param AbstractProblem problem: The formualation of the problem.
@@ -167,19 +167,18 @@ def __init__(
             weights_dict[condition_name] = Weights(weights_function)
         weights_dict = torch.nn.ModuleDict(weights_dict)
 
-
         super().__init__(
             models=[model, weights_dict],
             problem=problem,
             optimizers=[optimizer_model, optimizer_weights],
             optimizers_kwargs=[
                 optimizer_model_kwargs,
-                optimizer_weights_kwargs
+                optimizer_weights_kwargs,
             ],
             extra_features=extra_features,
-            loss=loss
+            loss=loss,
         )
-        
+
         # set automatic optimization
         self.automatic_optimization = False
@@ -191,12 +190,8 @@ def __init__(
 
         # assign schedulers
         self._schedulers = [
-            scheduler_model(
-                self.optimizers[0], **scheduler_model_kwargs
-            ),
-            scheduler_weights(
-                self.optimizers[1], **scheduler_weights_kwargs
-            ),
+            scheduler_model(self.optimizers[0], **scheduler_model_kwargs),
+            scheduler_weights(self.optimizers[1], **scheduler_weights_kwargs),
         ]
 
         self._model = self.models[0]
@@ -204,7 +199,7 @@ def __init__(
 
         self._vectorial_loss = deepcopy(loss)
         self._vectorial_loss.reduction = "none"
-    
+
     def forward(self, x):
         """
         Forward pass implementation for the PINN
@@ -219,7 +214,7 @@ def forward(self, x):
         :rtype: LabelTensor
         """
         return self.neural_net(x)
-    
+
     def loss_phys(self, samples, equation):
         """
         Computes the physics loss for the SAPINN solver based on given
@@ -235,7 +230,7 @@ def loss_phys(self, samples, equation):
         # train weights
         self.optimizer_weights.zero_grad()
         weighted_loss, _ = self._loss_phys(samples, equation)
-        loss_value = - weighted_loss.as_subclass(torch.Tensor)
+        loss_value = -weighted_loss.as_subclass(torch.Tensor)
         self.manual_backward(loss_value)
         self.optimizer_weights.step()
 
@@ -271,7 +266,7 @@ def loss_data(self, input_tensor, output_tensor):
         # train weights
         self.optimizer_weights.zero_grad()
         weighted_loss, _ = self._loss_data(input_tensor, output_tensor)
-        loss_value = - weighted_loss.as_subclass(torch.Tensor)
+        loss_value = -weighted_loss.as_subclass(torch.Tensor)
         self.manual_backward(loss_value)
         self.optimizer_weights.step()
 
@@ -291,7 +286,7 @@ def loss_data(self, input_tensor, output_tensor):
         # store loss without weights
         self.store_log(loss_value=float(loss))
         return loss_value
-    
+
     def configure_optimizers(self):
         """
         Optimizer configuration for the SAPINN
@@ -312,8 +307,8 @@ def configure_optimizers(self):
             }
         )
         return self.optimizers, self._schedulers
-    
-    def on_train_batch_end(self,outputs, batch, batch_idx):
+
+    def on_train_batch_end(self, outputs, batch, batch_idx):
         """
         This method is called at the end of each training batch, and ovverides
         the PytorchLightining implementation for logging the checkpoints.
@@ -327,9 +322,11 @@ def on_train_batch_end(self,outputs, batch, batch_idx):
         :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
-        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += 1
+        self.trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.total.completed += (
+            1
+        )
         return super().on_train_batch_end(outputs, batch, batch_idx)
-    
+
     def on_train_start(self):
         """
         This method is called at the start of the training for setting
@@ -343,12 +340,11 @@ def on_train_start(self):
             self.trainer._accelerator_connector._accelerator_flag
         )
         for condition_name, tensor in self.problem.input_pts.items():
-            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1),
-                device = device
+            self.weights_dict.torchmodel[condition_name].sa_weights.data = (
+                torch.rand((tensor.shape[0], 1), device=device)
             )
         return super().on_train_start()
-    
+
     def on_load_checkpoint(self, checkpoint):
         """
         Overriding the Pytorch Lightning ``on_load_checkpoint`` to handle
@@ -358,8 +354,8 @@ def on_load_checkpoint(self, checkpoint):
         :param dict checkpoint: Pytorch Lightning checkpoint dict.
         """
         for condition_name, tensor in self.problem.input_pts.items():
-            self.weights_dict.torchmodel[condition_name].sa_weights.data = torch.rand(
-                (tensor.shape[0], 1)
+            self.weights_dict.torchmodel[condition_name].sa_weights.data = (
+                torch.rand((tensor.shape[0], 1))
             )
         return super().on_load_checkpoint(checkpoint)
 
@@ -370,13 +366,13 @@ def _loss_phys(self, samples, equation):
         :param LabelTensor samples: Input samples to evaluate the physics loss.
         :param EquationInterface equation: the governing equation representing
             the physics.
-        
+
         :return: tuple with weighted and not weighted scalar loss
         :rtype: List[LabelTensor, LabelTensor]
         """
         residual = self.compute_residual(samples, equation)
         return self._compute_loss(residual)
-    
+
     def _loss_data(self, input_tensor, output_tensor):
         """
         Elaboration of the loss related to data for the SAPINN solver.
 
         :param LabelTensor input_tensor: The input to the neural networks.
         :param LabelTensor output_tensor: The true solution to compare the
             network solution.
-        
+
         :return: tuple with weighted and not weighted scalar loss
         :rtype: List[LabelTensor, LabelTensor]
         """
         return self._compute_loss(self.forward(input_tensor) - output_tensor)
 
@@ -396,19 +392,21 @@ def _compute_loss(self, residual):
         Elaboration of the pointwise loss through the mask model and the
         self adaptive weights
 
         :param LabelTensor residual: the matrix of residuals that have to
             be weighted
 
         :return: tuple with weighted and not weighted loss
         :rtype List[LabelTensor, LabelTensor]
         """
         weights = self.weights_dict.torchmodel[
-            self.current_condition_name].forward()
-        loss_value = self._vectorial_loss(torch.zeros_like(
-            residual, requires_grad=True), residual)
+            self.current_condition_name
+        ].forward()
+        loss_value = self._vectorial_loss(
+            torch.zeros_like(residual, requires_grad=True), residual
+        )
         return (
             self._vect_to_scalar(weights * loss_value),
-            self._vect_to_scalar(loss_value)
+            self._vect_to_scalar(loss_value),
         )
 
     def _vect_to_scalar(self, loss_value):
@@ -426,10 +424,11 @@ def _vect_to_scalar(self, loss_value):
         elif self.loss.reduction == "sum":
             ret = torch.sum(loss_value)
         else:
-            raise RuntimeError(f"Invalid reduction, got {self.loss.reduction} "
-                               "but expected mean or sum.")
+            raise RuntimeError(
+                f"Invalid reduction, got {self.loss.reduction} "
+                "but expected mean or sum."
+            )
         return ret
-    
 
     @property
     def neural_net(self):
         """
@@ -440,7 +439,7 @@ def neural_net(self):
         :rtype: torch.nn.Module
         """
         return self.models[0]
-    
+
     @property
     def weights_dict(self):
         """
@@ -462,7 +461,7 @@ def scheduler_model(self):
         :rtype: torch.optim.lr_scheduler._LRScheduler
         """
         return self._scheduler[0]
-    
+
     @property
     def scheduler_weights(self):
         """
@@ -482,7 +481,7 @@ def optimizer_model(self):
         :rtype: torch.optim.Optimizer
         """
         return self.optimizers[0]
-    
+
     @property
     def optimizer_weights(self):
         """
@@ -491,4 +490,4 @@ def optimizer_weights(self):
         :return: The optimizer for the mask model.
         :rtype: torch.optim.Optimizer
         """
-        return self.optimizers[1]
\ No newline at end of file
+        return self.optimizers[1]
diff --git a/pina/solvers/rom.py b/pina/solvers/rom.py
index 733d76f4..ee4bcff4 100644
--- a/pina/solvers/rom.py
+++ b/pina/solvers/rom.py
@@ -4,6 +4,7 @@
 
 from pina.solvers import SupervisedSolver
 
+
 class ReducedOrderModelSolver(SupervisedSolver):
     r"""
     ReducedOrderModelSolver solver class. This class implements a
@@ -114,10 +115,13 @@ def __init__(
             rate scheduler.
         :param dict scheduler_kwargs: LR scheduler constructor keyword args.
         """
-        model = torch.nn.ModuleDict({
-            'reduction_network' : reduction_network,
-            'interpolation_network' : interpolation_network})
-        
+        model = torch.nn.ModuleDict(
+            {
+                "reduction_network": reduction_network,
+                "interpolation_network": interpolation_network,
+            }
+        )
+
         super().__init__(
             model=model,
             problem=problem,
@@ -125,18 +129,22 @@ def __init__(
             optimizer=optimizer,
             optimizer_kwargs=optimizer_kwargs,
             scheduler=scheduler,
-            scheduler_kwargs=scheduler_kwargs
+            scheduler_kwargs=scheduler_kwargs,
         )
 
         # assert reduction object contains encode/ decode
-        if not hasattr(self.neural_net['reduction_network'], 'encode'):
-            raise SyntaxError('reduction_network must have encode method. '
-                              'The encode method should return a lower '
-                              'dimensional representation of the input.')
-        if not hasattr(self.neural_net['reduction_network'], 'decode'):
-            raise SyntaxError('reduction_network must have decode method. '
-                              'The decode method should return a high '
-                              'dimensional representation of the encoding.')
+        if not hasattr(self.neural_net["reduction_network"], "encode"):
+            raise SyntaxError(
+                "reduction_network must have encode method. "
+                "The encode method should return a lower "
+                "dimensional representation of the input."
+            )
+        if not hasattr(self.neural_net["reduction_network"], "decode"):
+            raise SyntaxError(
+                "reduction_network must have decode method. "
+                "The decode method should return a high "
+                "dimensional representation of the encoding."
+            )
 
     def forward(self, x):
         """
@@ -149,8 +157,8 @@ def forward(self, x):
         :return: Solver solution.
         :rtype: torch.Tensor
         """
-        reduction_network = self.neural_net['reduction_network']
-        interpolation_network = self.neural_net['interpolation_network']
+        reduction_network = self.neural_net["reduction_network"]
+        interpolation_network = self.neural_net["interpolation_network"]
         return reduction_network.decode(interpolation_network(x))
 
     def loss_data(self, input_pts, output_pts):
@@ -167,17 +175,18 @@ def loss_data(self, input_pts, output_pts):
         :rtype: torch.Tensor
         """
         # extract networks
-        reduction_network = self.neural_net['reduction_network']
-        interpolation_network = self.neural_net['interpolation_network']
+        reduction_network = self.neural_net["reduction_network"]
+        interpolation_network = self.neural_net["interpolation_network"]
         # encoded representations loss
         encode_repr_inter_net = interpolation_network(input_pts)
         encode_repr_reduction_network = reduction_network.encode(output_pts)
-        loss_encode = self.loss(encode_repr_inter_net,
-                                encode_repr_reduction_network)
+        loss_encode = self.loss(
+            encode_repr_inter_net, encode_repr_reduction_network
+        )
         # reconstruction loss
         loss_reconstruction = self.loss(
-            reduction_network.decode(encode_repr_reduction_network),
-            output_pts)
+            reduction_network.decode(encode_repr_reduction_network), output_pts
+        )
 
         return loss_encode + loss_reconstruction
diff --git a/pina/solvers/solver.py b/pina/solvers/solver.py
index 729a9d48..ec2f40c8 100644
--- a/pina/solvers/solver.py
+++ b/pina/solvers/solver.py
@@ -142,13 +142,13 @@ def problem(self):
         """
         The problem formulation."""
         return self._pina_problem
-    
+
     def on_train_start(self):
         """
         On training epoch start this function is call to do global checks for
         the different solvers.
         """
-        
+
         # 1. Check the verison for dataloader
         dataloader = self.trainer.train_dataloader
         if sys.version_info < (3, 8):
diff --git a/pina/solvers/supervised.py b/pina/solvers/supervised.py
index 28a634b0..42536461 100644
--- a/pina/solvers/supervised.py
+++ b/pina/solvers/supervised.py
@@ -118,7 +118,7 @@ def training_step(self, batch, batch_idx):
         :return: The sum of the loss functions.
         :rtype: LabelTensor
         """
-        
+
         condition_idx = batch["condition"]
 
         for condition_id in range(condition_idx.min(), condition_idx.max() + 1):
@@ -162,7 +162,7 @@ def loss_data(self, input_pts, output_pts):
         :rtype: torch.Tensor
         """
         return self.loss(self.forward(input_pts), output_pts)
-    
+
     @property
     def scheduler(self):
         """
diff --git a/pina/trainer.py b/pina/trainer.py
index 90779a6e..40f4eb69 100644
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -67,9 +67,9 @@ def _create_or_update_loader(self):
         pb = self._model.problem
         if hasattr(pb, "unknown_parameters"):
             for key in pb.unknown_parameters:
-                pb.unknown_parameters[key] = torch.nn.Parameter(pb.unknown_parameters[key].data.to(device))
-
-
+                pb.unknown_parameters[key] = torch.nn.Parameter(
+                    pb.unknown_parameters[key].data.to(device)
+                )
 
     def train(self, **kwargs):
         """