diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py index 3b2c949..ef75315 100644 --- a/pyforecaster/forecasting_models/neural_forecasters.py +++ b/pyforecaster/forecasting_models/neural_forecasters.py @@ -197,7 +197,6 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat if self.load_path is not None: self.load(self.load_path) - self.n_out = self.n_out*2 if self.probabilistic else self.n_out self.model = self.set_arch() self.optimizer = optax.adamw(learning_rate=self.learning_rate) @@ -246,7 +245,7 @@ def init_arch(nn_init, n_inputs_x=1): def set_arch(self): model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x, - n_out=self.n_out) + n_out=self.n_out*2 if self.probabilistic else self.n_out) return model def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):