From 2d5cc090d70528f55dcaac28089dca76f27718fb Mon Sep 17 00:00:00 2001
From: nepslor
Date: Tue, 12 Dec 2023 17:49:43 +0100
Subject: [PATCH] corrected n_out bug with optuna

---
 pyforecaster/forecasting_models/neural_forecasters.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pyforecaster/forecasting_models/neural_forecasters.py b/pyforecaster/forecasting_models/neural_forecasters.py
index 3b2c949..ef75315 100644
--- a/pyforecaster/forecasting_models/neural_forecasters.py
+++ b/pyforecaster/forecasting_models/neural_forecasters.py
@@ -197,7 +197,6 @@ def __init__(self, learning_rate: float = 0.01, batch_size: int = None, load_pat
         if self.load_path is not None:
             self.load(self.load_path)
 
-        self.n_out = self.n_out*2 if self.probabilistic else self.n_out
         self.model = self.set_arch()
         self.optimizer = optax.adamw(learning_rate=self.learning_rate)
 
@@ -246,7 +245,7 @@ def init_arch(nn_init, n_inputs_x=1):
 
     def set_arch(self):
         model = FeedForwardModule(n_layers=self.n_layers, n_neurons=self.n_hidden_x,
-                                  n_out=self.n_out)
+                                  n_out=self.n_out*2 if self.probabilistic else self.n_out)
         return model
 
     def fit(self, inputs, targets, n_epochs=None, savepath_tr_plots=None, stats_step=None, rel_tol=None):